feat: type

pull/9184/head
Yeuoly 2 years ago
parent db8bf2a85e
commit 279dee485d
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -128,10 +128,6 @@ class BasicProviderConfig(BaseModel):
return mode return mode
raise ValueError(f'invalid mode value {value}') raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
type: Type = Field(..., description="The type of the credentials") type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials") name: str = Field(..., description="The name of the credentials")

@ -26,7 +26,7 @@ class UserToolProvider(BaseModel):
author: str author: str
name: str # identifier name: str # identifier
description: I18nObject description: I18nObject
icon: str icon: str | dict
label: I18nObject # label label: I18nObject # label
type: ToolProviderType type: ToolProviderType
masked_credentials: Optional[dict] = None masked_credentials: Optional[dict] = None

@ -208,8 +208,12 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers: if not db_providers:
return [] return []
app = db_providers.app
if not app:
raise ValueError("can not read app of workflow")
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] self.tools = [self._get_db_provider_tool(db_providers, app)]
return self.tools return self.tools

@ -1,4 +1,5 @@
import json import json
from datetime import datetime
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -13,7 +14,7 @@ from .model import Account, App, Tenant
from .types import StringUUID from .types import StringUUID
class BuiltinToolProvider(db.Model): class BuiltinToolProvider(Base):
""" """
This table stores the tool provider information for built-in tools for each tenant. This table stores the tool provider information for built-in tools for each tenant.
""" """
@ -25,61 +26,22 @@ class BuiltinToolProvider(db.Model):
) )
# id of the tool provider # id of the tool provider
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the tenant # id of the tenant
tenant_id = db.Column(StringUUID, nullable=True) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider # who created this tool provider
user_id = db.Column(StringUUID, nullable=False) user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# name of the tool provider # name of the tool provider
provider = db.Column(db.String(40), nullable=False) provider: Mapped[str] = mapped_column(db.String(40), nullable=False)
# credential of the tool provider # credential of the tool provider
encrypted_credentials = db.Column(db.Text, nullable=True) encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:
return json.loads(self.encrypted_credentials) return json.loads(self.encrypted_credentials)
class PublishedAppTool(db.Model):
"""
The table stores the apps published as a tool for each person.
"""
__tablename__ = 'tool_published_apps'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'),
db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool')
)
# id of the tool provider
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the app
app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False)
# who published this tool
user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False)
# query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field
query_description = db.Column(db.Text, nullable=False)
# query name, the name of the query parameter
query_name = db.Column(db.String(40), nullable=False)
# name of the tool provider
tool_name = db.Column(db.String(40), nullable=False)
# author
author = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))
@property
def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first()
class ApiToolProvider(Base): class ApiToolProvider(Base):
""" """
The table stores the api providers. The table stores the api providers.
@ -129,14 +91,14 @@ class ApiToolProvider(Base):
return json.loads(self.credentials_str) return json.loads(self.credentials_str)
@property @property
def user(self) -> Account: def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).filter(Account.id == self.user_id).first()
@property @property
def tenant(self) -> Tenant: def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(db.Model): class ToolLabelBinding(Base):
""" """
The table stores the labels for tools. The table stores the labels for tools.
""" """
@ -146,15 +108,15 @@ class ToolLabelBinding(db.Model):
db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'),
) )
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# tool id # tool id
tool_id = db.Column(db.String(64), nullable=False) tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False)
# tool type # tool type
tool_type = db.Column(db.String(40), nullable=False) tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
# label name # label name
label_name = db.Column(db.String(40), nullable=False) label_name: Mapped[str] = mapped_column(db.String(40), nullable=False)
class WorkflowToolProvider(db.Model): class WorkflowToolProvider(Base):
""" """
The table stores the workflow providers. The table stores the workflow providers.
""" """
@ -165,41 +127,37 @@ class WorkflowToolProvider(db.Model):
db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'),
) )
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# name of the workflow provider # name of the workflow provider
name = db.Column(db.String(40), nullable=False) name: Mapped[str] = mapped_column(db.String(40), nullable=False)
# label of the workflow provider # label of the workflow provider
label = db.Column(db.String(255), nullable=False, server_default='') label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='')
# icon # icon
icon = db.Column(db.String(255), nullable=False) icon: Mapped[str] = mapped_column(db.String(255), nullable=False)
# app id of the workflow provider # app id of the workflow provider
app_id = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# version of the workflow provider # version of the workflow provider
version = db.Column(db.String(255), nullable=False, server_default='') version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default='')
# who created this tool # who created this tool
user_id = db.Column(StringUUID, nullable=False) user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id # tenant id
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# description of the provider # description of the provider
description = db.Column(db.Text, nullable=False) description: Mapped[str] = mapped_column(db.Text, nullable=False)
# parameter configuration # parameter configuration
parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]') parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default='[]')
# privacy policy # privacy policy
privacy_policy = db.Column(db.String(255), nullable=True, server_default='') privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default='')
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property @property
def schema_type(self) -> ApiProviderSchemaType: def user(self) -> Account | None:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def user(self) -> Account:
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).filter(Account.id == self.user_id).first()
@property @property
def tenant(self) -> Tenant: def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
@property @property
@ -210,7 +168,7 @@ class WorkflowToolProvider(db.Model):
] ]
@property @property
def app(self) -> App: def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first() return db.session.query(App).filter(App.id == self.app_id).first()
class ToolModelInvoke(db.Model): class ToolModelInvoke(db.Model):

@ -28,10 +28,13 @@ class BuiltinToolManageService:
tools = provider_controller.get_tools() tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfigurationManager( tool_provider_configurations = ToolConfigurationManager(
tenant_id=tenant_id, provider_controller=provider_controller tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
) )
# check if user has added the provider # check if user has added the provider
builtin_provider: BuiltinToolProvider = ( builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
@ -75,7 +78,7 @@ class BuiltinToolManageService:
update builtin tool provider update builtin tool provider
""" """
# get if the provider exists # get if the provider exists
provider: BuiltinToolProvider = ( provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
@ -89,7 +92,13 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_name) provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials") raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
)
# get original credentials if exists # get original credentials if exists
if provider is not None: if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@ -132,7 +141,7 @@ class BuiltinToolManageService:
""" """
get builtin tool provider credentials get builtin tool provider credentials
""" """
provider: BuiltinToolProvider = ( provider_obj: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
@ -141,12 +150,17 @@ class BuiltinToolManageService:
.first() .first()
) )
if provider is None: if provider_obj is None:
return {} return {}
provider_controller = ToolManager.get_builtin_provider(provider.provider) provider_controller = ToolManager.get_builtin_provider(provider_obj.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
)
credentials = tool_configuration.decrypt_tool_credentials(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials) credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials return credentials
@ -155,7 +169,7 @@ class BuiltinToolManageService:
""" """
delete tool provider delete tool provider
""" """
provider: BuiltinToolProvider = ( provider_obj: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
@ -164,15 +178,20 @@ class BuiltinToolManageService:
.first() .first()
) )
if provider is None: if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}") raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider) db.session.delete(provider_obj)
db.session.commit() db.session.commit()
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name) provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name,
)
tool_configuration.delete_tool_credentials_cache() tool_configuration.delete_tool_credentials_cache()
return {"result": "success"} return {"result": "success"}
@ -212,8 +231,8 @@ class BuiltinToolManageService:
try: try:
# handle include, exclude # handle include, exclude
if is_filtered( if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller, data=provider_controller,
name_func=lambda x: x.identity.name, name_func=lambda x: x.identity.name,
): ):

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Optional, Union from typing import Literal, Optional, Union, overload
from configs import dify_config from configs import dify_config
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class ToolTransformService: class ToolTransformService:
@classmethod @classmethod
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
""" """
get tool provider icon url get tool provider icon url
""" """
@ -35,7 +35,9 @@ class ToolTransformService:
return url_prefix + "builtin/" + provider_name + "/icon" return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
try: try:
return json.loads(icon) if isinstance(icon, str):
return json.loads(icon)
return icon
except: except:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}
@ -92,7 +94,8 @@ class ToolTransformService:
# get credentials schema # get credentials schema
schema = provider_controller.get_credentials_schema() schema = provider_controller.get_credentials_schema()
for name, value in schema.items(): for name, value in schema.items():
result.masked_credentials[name] = ProviderConfig.Type.default(value.type) if result.masked_credentials:
result.masked_credentials[name] = ""
# check if the provider need credentials # check if the provider need credentials
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
@ -184,9 +187,14 @@ class ToolTransformService:
""" """
username = "Anonymous" username = "Anonymous"
try: try:
username = db_provider.user.name user = db_provider.user
if not user:
raise ValueError("user not found")
username = user.name
except Exception as e: except Exception as e:
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
# add provider into providers # add provider into providers
credentials = db_provider.credentials credentials = db_provider.credentials
result = UserToolProvider( result = UserToolProvider(
@ -266,9 +274,9 @@ class ToolTransformService:
author=tool.identity.author, author=tool.identity.author,
name=tool.identity.name, name=tool.identity.name,
label=tool.identity.label, label=tool.identity.label,
description=tool.description.human, description=tool.description.human if tool.description else I18nObject(en_US=''),
parameters=current_parameters, parameters=current_parameters,
labels=labels, labels=labels or [],
) )
if isinstance(tool, ApiToolBundle): if isinstance(tool, ApiToolBundle):
return UserTool( return UserTool(
@ -277,5 +285,5 @@ class ToolTransformService:
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters, parameters=tool.parameters,
labels=labels, labels=labels or [],
) )

@ -4,7 +4,7 @@ from datetime import datetime
from sqlalchemy import or_ from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
@ -32,7 +32,7 @@ class WorkflowToolManageService:
description: str, description: str,
parameters: list[dict], parameters: list[dict],
privacy_policy: str = "", privacy_policy: str = "",
labels: list[str] = None, labels: list[str] | None = None,
) -> dict: ) -> dict:
""" """
Create a workflow tool. Create a workflow tool.
@ -62,12 +62,12 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None: if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
if app is None: if app is None:
raise ValueError(f"App {workflow_app_id} not found") raise ValueError(f"App {workflow_app_id} not found")
workflow: Workflow = app.workflow workflow: Workflow | None = app.workflow
if workflow is None: if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}") raise ValueError(f"Workflow not found for app {workflow_app_id}")
@ -106,7 +106,7 @@ class WorkflowToolManageService:
description: str, description: str,
parameters: list[dict], parameters: list[dict],
privacy_policy: str = "", privacy_policy: str = "",
labels: list[str] = None, labels: list[str] | None = None,
) -> dict: ) -> dict:
""" """
Update a workflow tool. Update a workflow tool.
@ -138,7 +138,7 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None: if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists") raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider = ( workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first() .first()
@ -147,14 +147,14 @@ class WorkflowToolManageService:
if workflow_tool_provider is None: if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found") raise ValueError(f"Tool {workflow_tool_id} not found")
app: App = ( app: App | None = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
) )
if app is None: if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found") raise ValueError(f"App {workflow_tool_provider.app_id} not found")
workflow: Workflow = app.workflow workflow: Workflow | None = app.workflow
if workflow is None: if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
@ -243,36 +243,12 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id :param workflow_app_id: the workflow app id
:return: the tool :return: the tool
""" """
db_tool: WorkflowToolProvider = ( db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first() .first()
) )
return cls._get_workflow_tool(db_tool)
if db_tool is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
"name": db_tool.name,
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod @classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
@ -283,19 +259,31 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id :param workflow_app_id: the workflow app id
:return: the tool :return: the tool
""" """
db_tool: WorkflowToolProvider = ( db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first() .first()
) )
return cls._get_workflow_tool(db_tool)
@classmethod
def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool
:return: the tool
"""
if db_tool is None: if db_tool is None:
raise ValueError(f"Tool {workflow_app_id} not found") raise ValueError("Tool not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() workflow_app: App | None = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
if workflow_app is None: if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found") raise ValueError(f"App {db_tool.app_id} not found")
workflow = workflow_app.workflow
if not workflow:
raise ValueError("Workflow not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool) tool = ToolTransformService.workflow_provider_to_controller(db_tool)
@ -308,14 +296,14 @@ class WorkflowToolManageService:
"description": db_tool.description, "description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations), "parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool( "tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
), ),
"synced": workflow_app.workflow.version == db_tool.version, "synced": workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy, "privacy_policy": db_tool.privacy_policy,
} }
@classmethod @classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
""" """
List workflow tool provider tools. List workflow tool provider tools.
:param user_id: the user id :param user_id: the user id
@ -323,7 +311,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id :param workflow_app_id: the workflow app id
:return: the list of tools :return: the list of tools
""" """
db_tool: WorkflowToolProvider = ( db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first() .first()
@ -336,6 +324,7 @@ class WorkflowToolManageService:
return [ return [
ToolTransformService.tool_to_user_tool( ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
) )
] ]

Loading…
Cancel
Save