diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 241b4a94de..642664e694 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -3,6 +3,8 @@ import logging from collections.abc import Generator from typing import Any, Optional, Union, cast +from sqlalchemy.orm import Session + from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -111,20 +113,33 @@ class WorkflowTool(Tool): yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) + def _get_end_user(self, session: Session, user_id: str) -> EndUser | None: + return session.query(EndUser).filter(EndUser.id == user_id).first() + + def _get_account_user(self, session: Session, user_id: str) -> Account | None: + account = session.query(Account).filter(Account.id == user_id).first() + if account: + account.load_and_populate_tenant(session=session, tenant_id=self.runtime.tenant_id) + return account + def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ get the user by user id """ - - user = db.session.query(EndUser).filter(EndUser.id == user_id).first() - if not user: - user = db.session.query(Account).filter(Account.id == user_id).first() - - if not user: + with Session(bind=db.engine, expire_on_commit=False) as session: + user: Account | EndUser | None = self._get_end_user(session=session, user_id=user_id) + # FIXME(QuantumGhost): It seems that for `EndUser`, workflow-as-tool may + # still work incorrectly. + if user: + return user + + user = self._get_account_user(session=session, user_id=user_id) + if user: + user.load_and_populate_tenant(session=session, tenant_id=self.runtime.tenant_id) + return user + # Neither Account nor EndUser is found. This should not happen. raise ValueError("user not found") - return user - def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a new tool with metadata diff --git a/api/models/account.py b/api/models/account.py index bb6a2a4735..9ace8c59f4 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,10 +4,9 @@ from typing import cast from flask_login import UserMixin # type: ignore from sqlalchemy import func -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column from models.base import Base - from .engine import db from .types import StringUUID @@ -60,6 +59,14 @@ class Account(UserMixin, Base): self._current_tenant = tenant + def load_and_populate_tenant(self, session: Session, tenant_id: str): + tenant = session.query(Tenant).where(Tenant.id == tenant_id).one() + ta = session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() + if ta: + tenant.current_role = ta.role + # FIXME(QuantumGhost): A temporary around for workflow-as-tool. + self._current_tenant = tenant + @property def current_tenant_id(self) -> str | None: return self._current_tenant.id if self._current_tenant else None