diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index d4dbe5836d..27727c2f3f 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -3,7 +3,7 @@ import json import flask_login # type: ignore from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in -from werkzeug.exceptions import Unauthorized +from werkzeug.exceptions import NotFound, Unauthorized import contexts from dify_app import DifyApp @@ -21,26 +21,32 @@ login_manager = flask_login.LoginManager() def load_user_from_request(request_from_flask_login): """Load user based on the request.""" auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: + if auth_header: if " " not in auth_header: raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme, auth_token = auth_header.split(maxsplit=1) auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + else: + auth_token = request.args.get("_token") - decoded = PassportService().verify(auth_token) if request.blueprint in {"console", "inner_api"}: + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + decoded = PassportService().verify(auth_token) user_id = decoded.get("user_id") logged_in_account = AccountService.load_logged_in_account(account_id=user_id) return logged_in_account - else: + elif request.blueprint == "web": + decoded = PassportService().verify(auth_token) + end_user_id = decoded.get("end_user_id") + if not end_user_id: + raise Unauthorized("Invalid Authorization token.") end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + if not end_user: + raise NotFound("End user not found.") return end_user