pull/21891/head
ytqh 1 year ago
parent c68c674413
commit b33f8fbecb

@ -1,9 +1,9 @@
from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service")
api = ExternalApi(bp)
from .app import app, audio, completion, conversation, file, message, workflow
from .auth import login
from .user import profile
from .user import profile

@ -4,19 +4,18 @@ from enum import Enum
from functools import wraps
from typing import Optional
from extensions.ext_database import db
from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.model import ApiToken, App, EndUser
from pydantic import BaseModel # type: ignore
from services.feature_service import FeatureService
from sqlalchemy import select, update # type: ignore
from sqlalchemy.orm import Session # type: ignore
from werkzeug.exceptions import Forbidden, Unauthorized
class WhereisUserArg(Enum):
@ -35,7 +34,9 @@ class FetchUserArg(BaseModel):
# TODO: add auth jwt token check
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def validate_app_token(
view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None
):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
@ -51,7 +52,11 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
tenant = (
db.session.query(Tenant)
.filter(Tenant.id == app_model.tenant_id)
.first()
)
if tenant is None:
raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE:
@ -76,7 +81,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if user_id:
user_id = str(user_id)
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id)
kwargs["end_user"] = create_or_update_end_user_for_user_id(
app_model, user_id
)
return view_func(*args, **kwargs)
@ -101,13 +108,27 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
documents_upload_quota = features.documents_upload_quota
if resource == "members" and 0 < members.limit <= members.size:
raise Forbidden("The number of members has reached the limit of your subscription.")
raise Forbidden(
"The number of members has reached the limit of your subscription."
)
elif resource == "apps" and 0 < apps.limit <= apps.size:
raise Forbidden("The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Forbidden("The number of documents has reached the limit of your subscription.")
raise Forbidden(
"The number of apps has reached the limit of your subscription."
)
elif (
resource == "vector_space"
and 0 < vector_space.limit <= vector_space.size
):
raise Forbidden(
"The capacity of the vector space has reached the limit of your subscription."
)
elif (
resource == "documents"
and 0 < documents_upload_quota.limit <= documents_upload_quota.size
):
raise Forbidden(
"The number of documents has reached the limit of your subscription."
)
else:
return view(*args, **kwargs)
@ -183,7 +204,9 @@ def validate_and_get_api_token(scope: str | None = None):
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
raise Unauthorized(
"Authorization header must be provided and start with 'Bearer'"
)
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
@ -198,7 +221,10 @@ def validate_and_get_api_token(scope: str | None = None):
update(ApiToken)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
(
ApiToken.last_used_at.is_(None)
| (ApiToken.last_used_at < cutoff_time)
),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
@ -208,7 +234,9 @@ def validate_and_get_api_token(scope: str | None = None):
api_token = result.scalar_one_or_none()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
stmt = select(ApiToken).where(
ApiToken.token == auth_token, ApiToken.type == scope
)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
@ -218,7 +246,9 @@ def validate_and_get_api_token(scope: str | None = None):
return api_token
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
def create_or_update_end_user_for_user_id(
app_model: App, user_id: Optional[str] = None
) -> EndUser:
"""
Create or update session terminal based on user ID.
"""

Loading…
Cancel
Save