|
|
|
@ -7,7 +7,10 @@ from werkzeug.exceptions import Unauthorized
|
|
|
|
|
|
|
|
|
|
|
|
import contexts
|
|
|
|
import contexts
|
|
|
|
from dify_app import DifyApp
|
|
|
|
from dify_app import DifyApp
|
|
|
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
from libs.passport import PassportService
|
|
|
|
from libs.passport import PassportService
|
|
|
|
|
|
|
|
from models.account import Account
|
|
|
|
|
|
|
|
from models.model import EndUser
|
|
|
|
from services.account_service import AccountService
|
|
|
|
from services.account_service import AccountService
|
|
|
|
|
|
|
|
|
|
|
|
login_manager = flask_login.LoginManager()
|
|
|
|
login_manager = flask_login.LoginManager()
|
|
|
|
@ -17,10 +20,8 @@ login_manager = flask_login.LoginManager()
|
|
|
|
@login_manager.request_loader
|
|
|
|
@login_manager.request_loader
|
|
|
|
def load_user_from_request(request_from_flask_login):
|
|
|
|
def load_user_from_request(request_from_flask_login):
|
|
|
|
"""Load user based on the request."""
|
|
|
|
"""Load user based on the request."""
|
|
|
|
if request.blueprint not in {"console", "inner_api"}:
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
# Check if the user_id contains a dot, indicating the old format
|
|
|
|
|
|
|
|
auth_header = request.headers.get("Authorization", "")
|
|
|
|
auth_header = request.headers.get("Authorization", "")
|
|
|
|
|
|
|
|
# Check if the user_id contains a dot, indicating the old format
|
|
|
|
if not auth_header:
|
|
|
|
if not auth_header:
|
|
|
|
auth_token = request.args.get("_token")
|
|
|
|
auth_token = request.args.get("_token")
|
|
|
|
if not auth_token:
|
|
|
|
if not auth_token:
|
|
|
|
@ -34,17 +35,21 @@ def load_user_from_request(request_from_flask_login):
|
|
|
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
|
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
|
|
|
|
|
|
|
|
|
|
decoded = PassportService().verify(auth_token)
|
|
|
|
decoded = PassportService().verify(auth_token)
|
|
|
|
user_id = decoded.get("user_id")
|
|
|
|
if request.blueprint in {"console", "inner_api"}:
|
|
|
|
|
|
|
|
user_id = decoded.get("user_id")
|
|
|
|
|
|
|
|
|
|
|
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
|
|
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
|
|
|
return logged_in_account
|
|
|
|
return logged_in_account
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
|
|
|
|
|
|
|
|
return end_user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@user_logged_in.connect
|
|
|
|
@user_logged_in.connect
|
|
|
|
@user_loaded_from_request.connect
|
|
|
|
@user_loaded_from_request.connect
|
|
|
|
def on_user_logged_in(_sender, user):
|
|
|
|
def on_user_logged_in(_sender, user):
|
|
|
|
"""Called when a user logged in."""
|
|
|
|
"""Called when a user logged in."""
|
|
|
|
if user:
|
|
|
|
if user and isinstance(user, Account) and user.current_tenant_id:
|
|
|
|
contexts.tenant_id.set(user.current_tenant_id)
|
|
|
|
contexts.tenant_id.set(user.current_tenant_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|