feat: Integrates Flask-Login for EndUser management

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/19980/head
-LAN- 1 year ago
parent d12de3d532
commit e472daa4f6
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,7 +1,9 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union, cast from typing import Any, Optional, cast
from flask_login import current_user
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
@ -87,7 +89,7 @@ class WorkflowTool(Tool):
result = generator.generate( result = generator.generate(
app_model=app, app_model=app,
workflow=workflow, workflow=workflow,
user=self._get_user(user_id), user=cast("Account | EndUser", current_user),
args={"inputs": tool_parameters, "files": files}, args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from, invoke_from=self.runtime.invoke_from,
streaming=False, streaming=False,
@ -111,20 +113,6 @@ class WorkflowTool(Tool):
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs) yield self.create_json_message(outputs)
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""
get the user by user id
"""
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
if not user:
user = db.session.query(Account).filter(Account.id == user_id).first()
if not user:
raise ValueError("user not found")
return user
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
""" """
fork a new tool with metadata fork a new tool with metadata

@ -7,7 +7,10 @@ from werkzeug.exceptions import Unauthorized
import contexts import contexts
from dify_app import DifyApp from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from models.account import Account
from models.model import EndUser
from services.account_service import AccountService from services.account_service import AccountService
login_manager = flask_login.LoginManager() login_manager = flask_login.LoginManager()
@ -17,10 +20,8 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader @login_manager.request_loader
def load_user_from_request(request_from_flask_login): def load_user_from_request(request_from_flask_login):
"""Load user based on the request.""" """Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
# Check if the user_id contains a dot, indicating the old format
if not auth_header: if not auth_header:
auth_token = request.args.get("_token") auth_token = request.args.get("_token")
if not auth_token: if not auth_token:
@ -34,17 +35,21 @@ def load_user_from_request(request_from_flask_login):
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token) decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id") if request.blueprint in {"console", "inner_api"}:
user_id = decoded.get("user_id")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id) logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account return logged_in_account
else:
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
return end_user
@user_logged_in.connect @user_logged_in.connect
@user_loaded_from_request.connect @user_loaded_from_request.connect
def on_user_logged_in(_sender, user): def on_user_logged_in(_sender, user):
"""Called when a user logged in.""" """Called when a user logged in."""
if user: if user and isinstance(user, Account) and user.current_tenant_id:
contexts.tenant_id.set(user.current_tenant_id) contexts.tenant_id.set(user.current_tenant_id)

Loading…
Cancel
Save