From f783ad68e4a492c7e5e80873bfbd2b53090ab5c2 Mon Sep 17 00:00:00 2001 From: Novice Date: Wed, 25 Jun 2025 14:09:19 +0800 Subject: [PATCH] chore(refactor): queries in service and auth components --- .../console/workspace/tool_providers.py | 22 ++- api/core/agent/base_agent_runner.py | 12 +- api/core/mcp/auth/auth_flow.py | 7 +- api/core/mcp/auth/auth_provider.py | 40 ++--- api/core/mcp/mcp_client.py | 3 +- api/core/tools/mcp_tool/provider.py | 8 +- api/models/tools.py | 37 +++++ api/services/tools/mcp_tools_mange_service.py | 152 +++++++----------- api/services/tools/tools_transform_service.py | 4 +- 9 files changed, 136 insertions(+), 149 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index ef0d23e280..a8559e7b9b 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -23,6 +23,7 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.tools.mcp_tools_mange_service import MCPToolManageService from services.tools.tool_labels_service import ToolLabelsService from services.tools.tools_manage_service import ToolCommonService +from services.tools.tools_transform_service import ToolTransformService from services.tools.workflow_tools_manage_service import WorkflowToolManageService @@ -693,27 +694,26 @@ class ToolMCPAuthApi(Resource): provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) if not provider: raise ValueError("provider not found") - server_url = MCPToolManageService.get_mcp_provider_server_url(tenant_id, provider_id) try: with MCPClient( - server_url, + provider.decrypted_server_url, provider_id, tenant_id, authed=False, authorization_code=args["authorization_code"], + for_list=True, ): MCPToolManageService.update_mcp_provider_credentials( - tenant_id=tenant_id, - provider_id=provider_id, - credentials=MCPToolManageService.get_mcp_provider_decrypted_credentials(tenant_id, provider_id), + mcp_provider=provider, + credentials=provider.decrypted_credentials, authed=True, ) return {"result": "success"} except MCPAuthError: - auth_provider = OAuthClientProvider(provider_id, tenant_id) + auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True) - return auth(auth_provider, server_url, args["authorization_code"]) + return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"]) class ToolMCPDetailApi(Resource): @@ -722,12 +722,8 @@ class ToolMCPDetailApi(Resource): @account_initialization_required def get(self, provider_id): user = current_user - return jsonable_encoder( - MCPToolManageService.retrieve_mcp_provider( - tenant_id=user.current_tenant_id, - provider_id=provider_id, - ) - ) + provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id) + return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider)) class ToolMCPListAllApi(Resource): diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bb41906f9..0d304de97a 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -258,10 +258,14 @@ class BaseAgentRunner(AppRunner): if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] if parameter.options else [] - prompt_tool.parameters["properties"][parameter.name] = { - "type": parameter_type, - "description": parameter.llm_description or "", - } + prompt_tool.parameters["properties"][parameter.name] = ( + { + "type": parameter_type, + "description": parameter.llm_description or "", + } + if parameter.input_schema is None + else parameter.input_schema + ) if len(enum) > 0: prompt_tool.parameters["properties"][parameter.name]["enum"] = enum diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index d9917a3fbb..1b6afd0b06 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -98,7 +98,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta full_state_data.code_verifier, full_state_data.redirect_uri, ) - provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id) + provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True) provider.save_tokens(tokens) return full_state_data @@ -275,6 +275,7 @@ def auth( server_url: str, authorization_code: Optional[str] = None, state_param: Optional[str] = None, + for_list: bool = False, ) -> dict[str, str]: """Orchestrates the full auth flow with a server using secure Redis state storage.""" metadata = discover_oauth_metadata(server_url) @@ -337,8 +338,8 @@ def auth( metadata, client_information, provider.redirect_url, - provider.provider_id, - provider.tenant_id, + provider.mcp_provider.id, + provider.mcp_provider.tenant_id, ) provider.save_code_verifier(code_verifier) diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index 80e165f10d..09d8924b79 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -7,18 +7,20 @@ from core.mcp.types import ( OAuthClientMetadata, OAuthTokens, ) +from models.tools import MCPToolProvider from services.tools.mcp_tools_mange_service import MCPToolManageService LATEST_PROTOCOL_VERSION = "1.0" class OAuthClientProvider: - provider_id: str - tenant_id: str + mcp_provider: MCPToolProvider - def __init__(self, provider_id: str, tenant_id: str): - self.provider_id = provider_id - self.tenant_id = tenant_id + def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False): + if for_list: + self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) + else: + self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id) @property def redirect_url(self) -> str: @@ -39,12 +41,7 @@ class OAuthClientProvider: def client_information(self) -> Optional[OAuthClientInformation]: """Loads information about this OAuth client.""" - mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) - if not mcp_provider: - return None - client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials( - self.tenant_id, self.provider_id - ).get("client_information", {}) + client_information = self.mcp_provider.decrypted_credentials.get("client_information", {}) if not client_information: return None return OAuthClientInformation.model_validate(client_information) @@ -52,15 +49,13 @@ class OAuthClientProvider: def save_client_information(self, client_information: OAuthClientInformationFull) -> None: """Saves client information after dynamic registration.""" MCPToolManageService.update_mcp_provider_credentials( - self.tenant_id, self.provider_id, {"client_information": client_information.model_dump()} + self.mcp_provider, + {"client_information": client_information.model_dump()}, ) def tokens(self) -> Optional[OAuthTokens]: """Loads any existing OAuth tokens for the current session.""" - mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) - if not mcp_provider: - return None - credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id) + credentials = self.mcp_provider.decrypted_credentials if not credentials: return None return OAuthTokens( @@ -74,20 +69,13 @@ class OAuthClientProvider: """Stores new OAuth tokens for the current session.""" # update mcp provider credentials token_dict = tokens.model_dump() - MCPToolManageService.update_mcp_provider_credentials(self.tenant_id, self.provider_id, token_dict, authed=True) + MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) def save_code_verifier(self, code_verifier: str) -> None: """Saves a PKCE code verifier for the current session.""" - # update mcp provider credentials - MCPToolManageService.update_mcp_provider_credentials( - self.tenant_id, self.provider_id, {"code_verifier": code_verifier} - ) + MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) def code_verifier(self) -> str: """Loads the PKCE code verifier for the current session.""" # get code verifier from mcp provider credentials - mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) - if not mcp_provider: - return "" - credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id) - return str(credentials.get("code_verifier", "")) + return str(self.mcp_provider.decrypted_credentials.get("code_verifier", "")) diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 274f84f027..3a036a0278 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -22,6 +22,7 @@ class MCPClient: tenant_id: str, authed: bool = True, authorization_code: Optional[str] = None, + for_list: bool = False, ): # Initialize info self.provider_id = provider_id @@ -35,7 +36,7 @@ class MCPClient: if authed: from core.mcp.auth.auth_provider import OAuthClientProvider - self.provider = OAuthClientProvider(self.provider_id, self.tenant_id) + self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list) self.token = self.provider.tokens() # Initialize session and client objects diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index e8a850e550..77ae6a70e6 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -40,8 +40,6 @@ class MCPToolProviderController(ToolProviderController): @classmethod def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": - from services.tools.mcp_tools_mange_service import MCPToolManageService - """ from db provider """ @@ -55,7 +53,7 @@ class MCPToolProviderController(ToolProviderController): author=db_provider.user.name if db_provider.user else "Anonymous", name=remote_mcp_tool.name, label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name), - provider=db_provider.id, + provider=db_provider.server_identifier, icon=db_provider.icon, ), parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema), @@ -84,9 +82,9 @@ class MCPToolProviderController(ToolProviderController): credentials_schema=[], tools=tools, ), - provider_id=db_provider.id or "", + provider_id=db_provider.server_identifier or "", tenant_id=db_provider.tenant_id or "", - server_url=MCPToolManageService.get_mcp_provider_server_url(db_provider.tenant_id, db_provider.id), + server_url=db_provider.decrypted_server_url, ) def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: diff --git a/api/models/tools.py b/api/models/tools.py index 1266ff1169..3357d6455a 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,7 @@ import json from datetime import datetime from typing import Any, cast +from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated @@ -8,6 +9,7 @@ from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column from core.file import helpers as file_helpers +from core.helper import encrypter from core.mcp.types import Tool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle @@ -258,6 +260,41 @@ class MCPToolProvider(Base): except json.JSONDecodeError: return file_helpers.get_signed_file_url(self.icon) + @property + def decrypted_server_url(self) -> str: + return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url)) + + @property + def masked_server_url(self) -> str: + def mask_url(url: str, mask_char: str = "*") -> str: + """ + mask the url to a simple string + """ + parsed = urlparse(url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + if parsed.path and parsed.path != "/": + return f"{base_url}/{mask_char * 6}" + else: + return base_url + + return mask_url(self.decrypted_server_url) + + @property + def decrypted_credentials(self) -> dict: + from core.tools.mcp_tool.provider import MCPToolProviderController + from core.tools.utils.configuration import ProviderConfigEncrypter + + provider_controller = MCPToolProviderController._from_db(self) + + tool_configuration = ProviderConfigEncrypter( + tenant_id=self.tenant_id, + config=list(provider_controller.get_credentials_schema()), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.provider_id, + ) + return tool_configuration.decrypt(self.credentials, use_cache=False) + class ToolModelInvoke(Base): """ diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index b24ed897b1..7edb20190a 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -1,7 +1,6 @@ import hashlib import json from datetime import datetime -from urllib.parse import urlparse from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError @@ -18,18 +17,7 @@ from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService - -def mask_url(url: str, mask_char: str = "*"): - """ - mask the url to a simple string - """ - parsed = urlparse(url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - if parsed.path and parsed.path != "/": - return f"{base_url}/{mask_char * 6}" - else: - return base_url +UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]" class MCPToolManageService: @@ -38,15 +26,26 @@ class MCPToolManageService: """ @staticmethod - def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider | None: - return ( + def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: + res = ( db.session.query(MCPToolProvider) - .filter( - MCPToolProvider.id == provider_id, - MCPToolProvider.tenant_id == tenant_id, - ) + .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) + .first() + ) + if not res: + raise ValueError("MCP tool not found") + return res + + @staticmethod + def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider: + res = ( + db.session.query(MCPToolProvider) + .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier) .first() ) + if not res: + raise ValueError("MCP tool not found") + return res @staticmethod def create_mcp_provider( @@ -109,11 +108,11 @@ class MCPToolManageService: @classmethod def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id) - if mcp_provider is None: - raise ValueError("MCP tool not found") + try: - with MCPClient(server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: + with MCPClient( + mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True + ) as mcp_client: tools = mcp_client.list_tools() except MCPAuthError as e: raise ValueError("Please auth the tool first") @@ -130,25 +129,17 @@ class MCPToolManageService: type=ToolProviderType.MCP, icon=mcp_provider.icon, author=mcp_provider.user.name if mcp_provider.user else "Anonymous", - server_url=cls.get_masked_mcp_provider_server_url(tenant_id, provider_id), + server_url=mcp_provider.masked_server_url, updated_at=int(mcp_provider.updated_at.timestamp()), description=I18nObject(en_US="", zh_Hans=""), label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name), plugin_unique_identifier=mcp_provider.server_identifier, ) - @classmethod - def retrieve_mcp_provider(cls, tenant_id: str, provider_id: str): - provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if provider is None: - raise ValueError("MCP tool not found") - return ToolTransformService.mcp_provider_to_user_provider(provider).to_dict() - @classmethod def delete_mcp_tool(cls, tenant_id: str, provider_id: str): mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if mcp_tool is None: - raise ValueError("MCP tool not found") + db.session.delete(mcp_tool) db.session.commit() @@ -165,60 +156,38 @@ class MCPToolManageService: server_identifier: str, ): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if mcp_provider is None: - raise ValueError("MCP tool not found") + mcp_provider.name = name mcp_provider.icon = ( json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon ) mcp_provider.server_identifier = server_identifier - if "[__HIDDEN__]" in server_url: - db.session.commit() - return - encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - mcp_provider.server_url = encrypted_server_url - server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() - # if the server url is changed, we need to re-auth the tool - try: + if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url: + encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) + mcp_provider.server_url = encrypted_server_url + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() + if server_url_hash != mcp_provider.server_url_hash: - try: - with MCPClient( - server_url, - provider_id, - tenant_id, - authed=False, - ) as mcp_client: - tools = mcp_client.list_tools() - mcp_provider.authed = True - mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) - except MCPAuthError: - mcp_provider.authed = False - mcp_provider.tools = "[]" - mcp_provider.encrypted_credentials = "{}" + cls._re_auth_mcp_provider(mcp_provider, provider_id, tenant_id) mcp_provider.server_url_hash = server_url_hash + try: db.session.commit() except IntegrityError as e: db.session.rollback() - # Check if the error message contains the constraint name - if "unique_mcp_provider_name" in str(e.orig): - # Raise your custom exception - raise ValueError(f"A provider with name '{name}' already exists.") - elif "unique_mcp_provider_server_url" in str(e.orig): - # You can define another custom exception for the other constraint - raise ValueError(f"A provider for server URL '{server_url}' already exists.") + error_msg = str(e.orig) + if "unique_mcp_provider_name" in error_msg: + raise ValueError(f"MCP tool {name} already exists") + elif "unique_mcp_provider_server_url" in error_msg: + raise ValueError(f"MCP tool {server_url} already exists") else: - # Re-raise the original exception if it's not the one you're handling raise @classmethod - def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False): - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if mcp_provider is None: - raise ValueError("MCP tool not found") + def update_mcp_provider_credentials(cls, mcp_provider: MCPToolProvider, credentials: dict, authed: bool = False): provider_controller = MCPToolProviderController._from_db(mcp_provider) tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, + tenant_id=mcp_provider.tenant_id, config=list(provider_controller.get_credentials_schema()), provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.provider_id, @@ -229,27 +198,22 @@ class MCPToolManageService: db.session.commit() @classmethod - def get_mcp_provider_decrypted_credentials(cls, tenant_id: str, provider_id: str): - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if mcp_provider is None: - raise ValueError("MCP tool not found") - provider_controller = MCPToolProviderController._from_db(mcp_provider) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, - ) - return tool_configuration.decrypt(mcp_provider.credentials, use_cache=False) - - @classmethod - def get_mcp_provider_server_url(cls, tenant_id: str, provider_id: str): - mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - if mcp_provider is None: - raise ValueError("MCP tool not found") - return encrypter.decrypt_token(tenant_id, mcp_provider.server_url) - - @classmethod - def get_masked_mcp_provider_server_url(cls, tenant_id: str, provider_id: str): - server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id) - return mask_url(server_url) + def _re_auth_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str): + """re-auth mcp provider""" + try: + with MCPClient( + mcp_provider.decrypted_server_url, + provider_id, + tenant_id, + authed=False, + for_list=True, + ) as mcp_client: + tools = mcp_client.list_tools() + mcp_provider.authed = True + mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + except MCPAuthError: + mcp_provider.authed = False + mcp_provider.tools = "[]" + + # reset credentials + mcp_provider.encrypted_credentials = "{}" diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 1298ddc95e..f3ad123e1c 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -191,8 +191,6 @@ class ToolTransformService: @staticmethod def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity: - from services.tools.mcp_tools_mange_service import MCPToolManageService - return ToolProviderApiEntity( id=db_provider.server_identifier if not for_list else db_provider.id, author=db_provider.user.name if db_provider.user else "Anonymous", @@ -200,7 +198,7 @@ class ToolTransformService: icon=db_provider.provider_icon, type=ToolProviderType.MCP, is_team_authorization=db_provider.authed, - server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id), + server_url=db_provider.masked_server_url, tools=ToolTransformService.mcp_tool_to_user_tool( db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] ),