fix: invoke tool streamingly

pull/9184/head
Yeuoly 2 years ago
parent cf4e9f317e
commit 886a160115
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -4,8 +4,8 @@ from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
from models.provider import ProviderQuotaType from models.provider import ProviderQuotaType
@ -143,7 +143,7 @@ class ProviderConfig(BasicProviderConfig):
value: str = Field(..., description="The value of the option") value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option") label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelConfigScope | None scope: AppSelectorScope | ModelConfigScope | None = None
required: bool = False required: bool = False
default: Optional[Union[int, str]] = None default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None options: Optional[list[Option]] = None

@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum): class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider" PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache: class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):

@ -1,10 +1,11 @@
from typing import Literal, Optional from typing import Literal, Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool.tool import ToolParameter from core.tools.tool.tool import ToolParameter
@ -14,7 +15,7 @@ class UserTool(BaseModel):
label: I18nObject # label label: I18nObject # label
description: I18nObject description: I18nObject
parameters: Optional[list[ToolParameter]] = None parameters: Optional[list[ToolParameter]] = None
labels: list[str] = None labels: list[str] = Field(default_factory=list)
UserToolProviderTypeLiteral = Optional[Literal[ UserToolProviderTypeLiteral = Optional[Literal[
'builtin', 'api', 'workflow' 'builtin', 'api', 'workflow'
@ -32,8 +33,8 @@ class UserToolProvider(BaseModel):
original_credentials: Optional[dict] = None original_credentials: Optional[dict] = None
is_team_authorization: bool = False is_team_authorization: bool = False
allow_delete: bool = True allow_delete: bool = True
tools: list[UserTool] = None tools: list[UserTool] = Field(default_factory=list)
labels: list[str] = None labels: list[str] = Field(default_factory=list)
def to_dict(self) -> dict: def to_dict(self) -> dict:
# ------------- # -------------

@ -25,7 +25,7 @@ class ToolLabelEnum(Enum):
UTILITIES = 'utilities' UTILITIES = 'utilities'
OTHER = 'other' OTHER = 'other'
class ToolProviderType(Enum): class ToolProviderType(str, Enum):
""" """
Enum class for tool provider Enum class for tool provider
""" """
@ -181,7 +181,7 @@ class ToolParameter(BaseModel):
if options: if options:
option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
else: else:
option_objs = None option_objs = []
return cls( return cls(
name=name, name=name,
label=I18nObject(en_US='', zh_Hans=''), label=I18nObject(en_US='', zh_Hans=''),

@ -1,21 +1,23 @@
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ProviderConfig,
ToolCredentialsOption,
ToolProviderType, ToolProviderType,
) )
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
class ApiToolProviderController(ToolProviderController): class ApiToolProviderController(ToolProviderController):
provider_id: str provider_id: str
tenant_id: str
tools: list[ApiTool] = Field(default_factory=list)
@staticmethod @staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
@ -25,8 +27,8 @@ class ApiToolProviderController(ToolProviderController):
required=True, required=True,
type=ProviderConfig.Type.SELECT, type=ProviderConfig.Type.SELECT,
options=[ options=[
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='')), ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
], ],
default='none', default='none',
help=I18nObject( help=I18nObject(
@ -67,9 +69,9 @@ class ApiToolProviderController(ToolProviderController):
zh_Hans='api key header 的前缀' zh_Hans='api key header 的前缀'
), ),
options=[ options=[
ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
] ]
) )
} }
@ -96,6 +98,7 @@ class ApiToolProviderController(ToolProviderController):
}, },
'credentials_schema': credentials_schema, 'credentials_schema': credentials_schema,
'provider_id': db_provider.id or '', 'provider_id': db_provider.id or '',
'tenant_id': db_provider.tenant_id or '',
}) })
@property @property
@ -142,7 +145,7 @@ class ApiToolProviderController(ToolProviderController):
return self.tools return self.tools
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: def get_tools(self, tenant_id: str) -> list[ApiTool]:
""" """
fetch tools from database fetch tools from database
@ -153,7 +156,7 @@ class ApiToolProviderController(ToolProviderController):
if self.tools is not None: if self.tools is not None:
return self.tools return self.tools
tools: list[Tool] = [] tools: list[ApiTool] = []
# get tenant api providers # get tenant api providers
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
@ -179,7 +182,7 @@ class ApiToolProviderController(ToolProviderController):
:return: the tool :return: the tool
""" """
if self.tools is None: if self.tools is None:
self.get_tools() self.get_tools(self.tenant_id)
for tool in self.tools: for tool in self.tools:
if tool.identity.name == tool_name: if tool.identity.name == tool_name:

@ -39,7 +39,7 @@ class BuiltinToolProviderController(ToolProviderController):
super().__init__(**{ super().__init__(**{
'identity': provider_yaml['identity'], 'identity': provider_yaml['identity'],
'credentials_schema': provider_yaml.get('credentials_for_provider', None), 'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
}) })
def _get_builtin_tools(self) -> list[BuiltinTool]: def _get_builtin_tools(self) -> list[BuiltinTool]:

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
@ -17,6 +17,8 @@ class ToolProviderController(BaseModel, ABC):
tools: list[Tool] = Field(default_factory=list) tools: list[Tool] = Field(default_factory=list)
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
model_config = ConfigDict(validate_assignment=True)
def get_credentials_schema(self) -> dict[str, ProviderConfig]: def get_credentials_schema(self) -> dict[str, ProviderConfig]:
""" """
returns the credentials schema of the provider returns the credentials schema of the provider

@ -206,7 +206,16 @@ class Tool(BaseModel, ABC):
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
) )
return result if isinstance(result, ToolInvokeMessage):
def single_generator():
yield result
return single_generator()
elif isinstance(result, list):
def generator():
yield from result
return generator()
else:
return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
""" """
@ -223,7 +232,7 @@ class Tool(BaseModel, ABC):
return result return result
@abstractmethod @abstractmethod
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
pass pass
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

@ -116,7 +116,12 @@ class ToolManager:
# decrypt the credentials # decrypt the credentials
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
controller = cls.get_builtin_provider(provider_id) controller = cls.get_builtin_provider(provider_id)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value,
provider_identity=controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
@ -135,7 +140,12 @@ class ToolManager:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials # decrypt the credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=api_provider.get_credentials_schema(),
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
@ -513,7 +523,12 @@ class ToolManager:
provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
) )
# init tool configuration # init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value,
provider_identity=controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

@ -1,23 +1,25 @@
from collections.abc import Mapping
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ProviderConfig,
ToolParameter, ToolParameter,
ToolProviderType, ToolProviderType,
) )
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
class ToolConfigurationManager(BaseModel): class ToolConfigurationManager(BaseModel):
tenant_id: str tenant_id: str
provider_controller: ToolProviderController config: Mapping[str, BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]: def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
""" """
@ -34,9 +36,9 @@ class ToolConfigurationManager(BaseModel):
credentials = self._deep_copy(credentials) credentials = self._deep_copy(credentials)
# get fields need to be decrypted # get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema() fields = self.config
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials: if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name]) encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted credentials[field_name] = encrypted
@ -52,9 +54,9 @@ class ToolConfigurationManager(BaseModel):
credentials = self._deep_copy(credentials) credentials = self._deep_copy(credentials)
# get fields need to be decrypted # get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema() fields = self.config
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials: if field_name in credentials:
if len(credentials[field_name]) > 6: if len(credentials[field_name]) > 6:
credentials[field_name] = \ credentials[field_name] = \
@ -74,7 +76,7 @@ class ToolConfigurationManager(BaseModel):
""" """
cache = ToolProviderCredentialsCache( cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', identity_id=f'{self.provider_type}.{self.provider_identity}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER cache_type=ToolProviderCredentialsCacheType.PROVIDER
) )
cached_credentials = cache.get() cached_credentials = cache.get()
@ -82,9 +84,9 @@ class ToolConfigurationManager(BaseModel):
return cached_credentials return cached_credentials
credentials = self._deep_copy(credentials) credentials = self._deep_copy(credentials)
# get fields need to be decrypted # get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema() fields = self.config
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == ProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials: if field_name in credentials:
try: try:
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name]) credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
@ -97,7 +99,7 @@ class ToolConfigurationManager(BaseModel):
def delete_tool_credentials_cache(self): def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache( cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', identity_id=f'{self.provider_type}.{self.provider_identity}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER cache_type=ToolProviderCredentialsCacheType.PROVIDER
) )
cache.delete() cache.delete()

@ -16,7 +16,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
class ApiBasedToolSchemaParser: class ApiBasedToolSchemaParser:
@staticmethod @staticmethod
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
warning = warning if warning is not None else {} warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {} extra_info = extra_info if extra_info is not None else {}
@ -173,7 +173,7 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING return ToolParameter.ToolParameterType.STRING
@staticmethod @staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
""" """
parse openapi yaml to tool bundle parse openapi yaml to tool bundle
@ -189,7 +189,8 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod @staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict: def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
""" """
parse swagger to openapi parse swagger to openapi
@ -255,7 +256,7 @@ class ApiBasedToolSchemaParser:
return openapi return openapi
@staticmethod @staticmethod
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
""" """
parse openapi plugin yaml to tool bundle parse openapi plugin yaml to tool bundle
@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod @staticmethod
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]: def auto_parse_to_tool_bundle(content: str, extra_info: dict | None = None, warning: dict | None = None) -> tuple[list[ApiToolBundle], str]:
""" """
auto parse to tool bundle auto parse to tool bundle

@ -1,6 +1,6 @@
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from os import path from os import path
from typing import Any, cast from typing import Any, Iterable, cast
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -158,14 +158,17 @@ class ToolNode(BaseNode):
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
conversation_id=None, conversation_id=None,
) )
result = list(messages)
# extract plain text and files # extract plain text and files
files = self._extract_tool_response_binary(messages) files = self._extract_tool_response_binary(result)
plain_text = self._extract_tool_response_text(messages) plain_text = self._extract_tool_response_text(result)
json = self._extract_tool_response_json(messages) json = self._extract_tool_response_json(result)
return plain_text, files, json return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]: def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
""" """
Extract tool response binary Extract tool response binary
""" """
@ -215,7 +218,7 @@ class ToolNode(BaseNode):
return result return result
def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str: def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
""" """
Extract tool response text Extract tool response text
""" """
@ -230,7 +233,7 @@ class ToolNode(BaseNode):
return '\n'.join(result) return '\n'.join(result)
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]: def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
result: list[dict] = [] result: list[dict] = []
for message in tool_response: for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.JSON: if message.type == ToolInvokeMessage.MessageType.JSON:

@ -7,7 +7,7 @@ from typing import Optional
from flask import request from flask import request
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy import Float, func, text from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column, relationship
from configs import dify_config from configs import dify_config
from core.file.tool_file_parser import ToolFileParser from core.file.tool_file_parser import ToolFileParser
@ -495,14 +495,14 @@ class InstalledApp(db.Model):
return tenant return tenant
class Conversation(db.Model): class Conversation(Base):
__tablename__ = 'conversations' __tablename__ = 'conversations'
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='conversation_pkey'), db.PrimaryKeyConstraint('id', name='conversation_pkey'),
db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
) )
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(StringUUID, nullable=True) app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
@ -526,8 +526,8 @@ class Conversation(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") messages: Mapped[list["Message"]] = relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") message_annotations: Mapped[list["MessageAnnotation"]] = relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@ -660,10 +660,10 @@ class Message(Base):
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
inputs = db.Column(db.JSON) inputs: Mapped[str] = mapped_column(db.JSON)
query = db.Column(db.Text, nullable=False) query: Mapped[str] = mapped_column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False) message: Mapped[str] = mapped_column(db.JSON, nullable=False)
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
@ -944,7 +944,7 @@ class MessageFile(Base):
db.Index('message_file_created_by_idx', 'created_by') db.Index('message_file_created_by_idx', 'created_by')
) )
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
@ -956,7 +956,7 @@ class MessageFile(Base):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class MessageAnnotation(db.Model): class MessageAnnotation(Base):
__tablename__ = 'message_annotations' __tablename__ = 'message_annotations'
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='message_annotation_pkey'), db.PrimaryKeyConstraint('id', name='message_annotation_pkey'),
@ -967,7 +967,7 @@ class MessageAnnotation(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(StringUUID, nullable=True) message_id = db.Column(StringUUID, nullable=True)
question = db.Column(db.Text, nullable=True) question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)

@ -77,10 +77,10 @@ class PublishedAppTool(db.Model):
return I18nObject(**json.loads(self.description)) return I18nObject(**json.loads(self.description))
@property @property
def app(self) -> App: def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first() return db.session.query(App).filter(App.id == self.app_id).first()
class ApiToolProvider(db.Model): class ApiToolProvider(Base):
""" """
The table stores the api providers. The table stores the api providers.
""" """
@ -290,7 +290,7 @@ class ToolFile(Base):
db.Index('tool_file_conversation_id_idx', 'conversation_id'), db.Index('tool_file_conversation_id_idx', 'conversation_id'),
) )
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id # conversation user id
user_id: Mapped[str] = mapped_column(StringUUID) user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id # tenant id

@ -3,6 +3,7 @@ import logging
from httpx import get from httpx import get
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -10,8 +11,6 @@ from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ApiProviderSchemaType, ApiProviderSchemaType,
ProviderConfig,
ToolCredentialsOption,
) )
from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
@ -45,8 +44,8 @@ class ApiToolManageService:
required=True, required=True,
default="none", default="none",
options=[ options=[
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="")), ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
], ],
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
), ),
@ -79,15 +78,14 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}") raise ValueError(f"invalid schema: {str(e)}")
@staticmethod @staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
""" """
convert schema to tool bundles convert schema to tool bundles
:return: the list of tool bundles, description :return: the list of tool bundles, description
""" """
try: try:
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
return tool_bundles
except Exception as e: except Exception as e:
raise ValueError(f"invalid schema: {str(e)}") raise ValueError(f"invalid schema: {str(e)}")
@ -111,7 +109,7 @@ class ApiToolManageService:
raise ValueError(f"invalid schema type {schema}") raise ValueError(f"invalid schema type {schema}")
# check if the provider exists # check if the provider exists
provider: ApiToolProvider = ( provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -158,7 +156,13 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials # encrypt credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials) db_provider.credentials_str = json.dumps(encrypted_credentials)
@ -195,21 +199,21 @@ class ApiToolManageService:
return {"schema": schema} return {"schema": schema}
@staticmethod @staticmethod
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
""" """
list api tool provider tools list api tool provider tools
""" """
provider: ApiToolProvider = ( provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider, ApiToolProvider.name == provider_name,
) )
.first() .first()
) )
if provider is None: if provider is None:
raise ValueError(f"you have not added provider {provider}") raise ValueError(f"you have not added provider {provider_name}")
controller = ToolTransformService.api_provider_to_controller(db_provider=provider) controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(controller) labels = ToolLabelManager.get_tool_labels(controller)
@ -243,7 +247,7 @@ class ApiToolManageService:
raise ValueError(f"invalid schema type {schema}") raise ValueError(f"invalid schema type {schema}")
# check if the provider exists # check if the provider exists
provider: ApiToolProvider = ( provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -282,7 +286,12 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists # get original credentials if exists
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
@ -310,7 +319,7 @@ class ApiToolManageService:
""" """
delete tool provider delete tool provider
""" """
provider: ApiToolProvider = ( provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -360,7 +369,7 @@ class ApiToolManageService:
if tool_bundle is None: if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}") raise ValueError(f"invalid tool name {tool_name}")
db_provider: ApiToolProvider = ( db_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -396,7 +405,12 @@ class ApiToolManageService:
# decrypt credentials # decrypt credentials
if db_provider.id: if db_provider.id:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
@ -444,7 +458,7 @@ class ApiToolManageService:
# add icon # add icon
ToolTransformService.repack_provider(user_provider) ToolTransformService.repack_provider(user_provider)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) tools = provider_controller.get_tools(tenant_id=tenant_id)
for tool in tools: for tool in tools:
user_provider.tools.append( user_provider.tools.append(

@ -3,12 +3,12 @@ import logging
from typing import Optional, Union from typing import Optional, Union
from configs import dify_config from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ProviderConfig,
ToolParameter, ToolParameter,
ToolProviderType, ToolProviderType,
) )
@ -106,7 +106,10 @@ class ToolTransformService:
# init tool configuration # init tool configuration
tool_configuration = ToolConfigurationManager( tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id, provider_controller=provider_controller tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
@ -143,7 +146,7 @@ class ToolTransformService:
@staticmethod @staticmethod
def workflow_provider_to_user_provider( def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController, labels: list[str] = None provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
): ):
""" """
convert provider controller to user provider convert provider controller to user provider
@ -174,7 +177,7 @@ class ToolTransformService:
provider_controller: ApiToolProviderController, provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider, db_provider: ApiToolProvider,
decrypt_credentials: bool = True, decrypt_credentials: bool = True,
labels: list[str] = None, labels: list[str] | None = None,
) -> UserToolProvider: ) -> UserToolProvider:
""" """
convert provider controller to user provider convert provider controller to user provider
@ -209,7 +212,10 @@ class ToolTransformService:
if decrypt_credentials: if decrypt_credentials:
# init tool configuration # init tool configuration
tool_configuration = ToolConfigurationManager( tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id, provider_controller=provider_controller tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
@ -223,9 +229,9 @@ class ToolTransformService:
@staticmethod @staticmethod
def tool_to_user_tool( def tool_to_user_tool(
tool: Union[ApiToolBundle, WorkflowTool, Tool], tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None, credentials: dict | None = None,
tenant_id: str = None, tenant_id: str | None = None,
labels: list[str] = None, labels: list[str] | None = None,
) -> UserTool: ) -> UserTool:
""" """
convert tool to user tool convert tool to user tool

Loading…
Cancel
Save