pull/21891/head
ytqh 1 year ago
parent c68c674413
commit b33f8fbecb

@ -1,9 +1,9 @@
from flask import Blueprint from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service") bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service")
api = ExternalApi(bp) api = ExternalApi(bp)
from .app import app, audio, completion, conversation, file, message, workflow
from .auth import login from .auth import login
from .user import profile from .user import profile

@ -4,19 +4,18 @@ from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
from extensions.ext_database import db
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in # type: ignore from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
from pydantic import BaseModel # type: ignore
from services.feature_service import FeatureService from services.feature_service import FeatureService
from sqlalchemy import select, update # type: ignore
from sqlalchemy.orm import Session # type: ignore
from werkzeug.exceptions import Forbidden, Unauthorized
class WhereisUserArg(Enum): class WhereisUserArg(Enum):
@ -35,7 +34,9 @@ class FetchUserArg(BaseModel):
# TODO: add auth jwt token check # TODO: add auth jwt token check
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): def validate_app_token(
view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None
):
def decorator(view_func): def decorator(view_func):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):
@ -51,7 +52,11 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api: if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.") raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() tenant = (
db.session.query(Tenant)
.filter(Tenant.id == app_model.tenant_id)
.first()
)
if tenant is None: if tenant is None:
raise ValueError("Tenant does not exist.") raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
@ -76,7 +81,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if user_id: if user_id:
user_id = str(user_id) user_id = str(user_id)
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) kwargs["end_user"] = create_or_update_end_user_for_user_id(
app_model, user_id
)
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
@ -101,13 +108,27 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
documents_upload_quota = features.documents_upload_quota documents_upload_quota = features.documents_upload_quota
if resource == "members" and 0 < members.limit <= members.size: if resource == "members" and 0 < members.limit <= members.size:
raise Forbidden("The number of members has reached the limit of your subscription.") raise Forbidden(
"The number of members has reached the limit of your subscription."
)
elif resource == "apps" and 0 < apps.limit <= apps.size: elif resource == "apps" and 0 < apps.limit <= apps.size:
raise Forbidden("The number of apps has reached the limit of your subscription.") raise Forbidden(
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: "The number of apps has reached the limit of your subscription."
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.") )
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: elif (
raise Forbidden("The number of documents has reached the limit of your subscription.") resource == "vector_space"
and 0 < vector_space.limit <= vector_space.size
):
raise Forbidden(
"The capacity of the vector space has reached the limit of your subscription."
)
elif (
resource == "documents"
and 0 < documents_upload_quota.limit <= documents_upload_quota.size
):
raise Forbidden(
"The number of documents has reached the limit of your subscription."
)
else: else:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -183,7 +204,9 @@ def validate_and_get_api_token(scope: str | None = None):
""" """
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header: if auth_header is None or " " not in auth_header:
raise Unauthorized("Authorization header must be provided and start with 'Bearer'") raise Unauthorized(
"Authorization header must be provided and start with 'Bearer'"
)
auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower() auth_scheme = auth_scheme.lower()
@ -198,7 +221,10 @@ def validate_and_get_api_token(scope: str | None = None):
update(ApiToken) update(ApiToken)
.where( .where(
ApiToken.token == auth_token, ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)), (
ApiToken.last_used_at.is_(None)
| (ApiToken.last_used_at < cutoff_time)
),
ApiToken.type == scope, ApiToken.type == scope,
) )
.values(last_used_at=current_time) .values(last_used_at=current_time)
@ -208,7 +234,9 @@ def validate_and_get_api_token(scope: str | None = None):
api_token = result.scalar_one_or_none() api_token = result.scalar_one_or_none()
if not api_token: if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) stmt = select(ApiToken).where(
ApiToken.token == auth_token, ApiToken.type == scope
)
api_token = session.scalar(stmt) api_token = session.scalar(stmt)
if not api_token: if not api_token:
raise Unauthorized("Access token is invalid") raise Unauthorized("Access token is invalid")
@ -218,7 +246,9 @@ def validate_and_get_api_token(scope: str | None = None):
return api_token return api_token
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser: def create_or_update_end_user_for_user_id(
app_model: App, user_id: Optional[str] = None
) -> EndUser:
""" """
Create or update session terminal based on user ID. Create or update session terminal based on user ID.
""" """

Loading…
Cancel
Save