From e472daa4f6211caa94b29824c942994058c37bc5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 20 May 2025 16:26:20 +0800 Subject: [PATCH] feat: Integrates Flask-Login for EndUser management Signed-off-by: -LAN- --- api/core/tools/workflow_as_tool/tool.py | 20 ++++---------------- api/extensions/ext_login.py | 19 ++++++++++++------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 241b4a94de..57c93d1d45 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,9 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional, Union, cast +from typing import Any, Optional, cast + +from flask_login import current_user from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool @@ -87,7 +89,7 @@ class WorkflowTool(Tool): result = generator.generate( app_model=app, workflow=workflow, - user=self._get_user(user_id), + user=cast("Account | EndUser", current_user), args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, streaming=False, @@ -111,20 +113,6 @@ class WorkflowTool(Tool): yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) - def _get_user(self, user_id: str) -> Union[EndUser, Account]: - """ - get the user by user id - """ - - user = db.session.query(EndUser).filter(EndUser.id == user_id).first() - if not user: - user = db.session.query(Account).filter(Account.id == user_id).first() - - if not user: - raise ValueError("user not found") - - return user - def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a new tool with metadata diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 10fb89eb73..12a5874553 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -7,7 +7,10 @@ from werkzeug.exceptions import Unauthorized import contexts from dify_app import DifyApp +from extensions.ext_database import db from libs.passport import PassportService +from models.account import Account +from models.model import EndUser from services.account_service import AccountService login_manager = flask_login.LoginManager() @@ -17,10 +20,8 @@ login_manager = flask_login.LoginManager() @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in {"console", "inner_api"}: - return None - # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get("Authorization", "") + # Check if the user_id contains a dot, indicating the old format if not auth_header: auth_token = request.args.get("_token") if not auth_token: @@ -34,17 +35,21 @@ def load_user_from_request(request_from_flask_login): raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") + if request.blueprint in {"console", "inner_api"}: + user_id = decoded.get("user_id") - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) - return logged_in_account + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + return logged_in_account + else: + end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + return end_user @user_logged_in.connect @user_loaded_from_request.connect def on_user_logged_in(_sender, user): """Called when a user logged in.""" - if user: + if user and isinstance(user, Account) and user.current_tenant_id: contexts.tenant_id.set(user.current_tenant_id)