feat: Integrates Flask-Login for EndUser management

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/19980/head
-LAN- 1 year ago
parent d12de3d532
commit e472daa4f6
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,7 +1,9 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, cast
from flask_login import current_user
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@ -87,7 +89,7 @@ class WorkflowTool(Tool):
result = generator.generate(
app_model=app,
workflow=workflow,
user=self._get_user(user_id),
user=cast("Account | EndUser", current_user),
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,
@ -111,20 +113,6 @@ class WorkflowTool(Tool):
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""
get the user by user id
"""
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
if not user:
user = db.session.query(Account).filter(Account.id == user_id).first()
if not user:
raise ValueError("user not found")
return user
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
"""
fork a new tool with metadata

@ -7,7 +7,10 @@ from werkzeug.exceptions import Unauthorized
import contexts
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from models.account import Account
from models.model import EndUser
from services.account_service import AccountService
login_manager = flask_login.LoginManager()
@ -17,10 +20,8 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")
# Check if the user_id contains a dot, indicating the old format
if not auth_header:
auth_token = request.args.get("_token")
if not auth_token:
@ -34,17 +35,21 @@ def load_user_from_request(request_from_flask_login):
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token)
if request.blueprint in {"console", "inner_api"}:
user_id = decoded.get("user_id")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account
else:
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
return end_user
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
"""Called when a user logged in."""
if user:
if user and isinstance(user, Account) and user.current_tenant_id:
contexts.tenant_id.set(user.current_tenant_id)

Loading…
Cancel
Save