Merge branch 'feat/mcp' into deploy/dev

pull/22338/head^2
Novice 11 months ago
commit 24c5ee1d6d

@ -15,7 +15,7 @@ from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
@ -942,8 +942,14 @@ class ToolMCPAuthApi(Resource):
except MCPAuthError: except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True) auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"]) return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
except MCPError as e:
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials={},
authed=False,
)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
class ToolMCPDetailApi(Resource): class ToolMCPDetailApi(Resource):

@ -284,7 +284,7 @@ def sse_client(
try: try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client: with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()

@ -185,7 +185,6 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
self.url, self.url,
2,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client, client=client,
@ -215,7 +214,6 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
self.url, self.url,
2,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client, client=ctx.client,

@ -179,10 +179,7 @@ class BaseSession(
def check_receiver_status(self) -> None: def check_receiver_status(self) -> None:
if self._receiver_future.done(): if self._receiver_future.done():
try:
self._receiver_future.result() self._receiver_future.result()
except Exception as e:
raise e
def __exit__( def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None

@ -6,8 +6,6 @@ from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError from core.mcp.types import ErrorData, JSONRPCError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
STATUS_FORCELIST = [429, 500, 502, 503, 504] STATUS_FORCELIST = [429, 500, 502, 503, 504]
@ -57,7 +55,7 @@ def create_ssrf_proxy_mcp_http_client(
) )
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def ssrf_proxy_sse_connect(url, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection. """Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings This function creates an SSE connection using the configured proxy settings
@ -65,7 +63,6 @@ def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
Args: Args:
url: The SSE endpoint URL url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection **kwargs: Additional arguments passed to the SSE connection
Returns: Returns:

@ -39,14 +39,19 @@ class MCPTool(Tool):
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]: ) -> Generator[ToolInvokeMessage, None, None]:
from core.tools.errors import ToolInvokeError
try: try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters) tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e: except MCPAuthError as e:
raise ValueError("Please auth the tool first") raise ToolInvokeError("Please auth the tool first") from e
except MCPConnectionError as e: except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}") raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
for content in result.content: for content in result.content:
if isinstance(content, TextContent): if isinstance(content, TextContent):
yield self.create_text_message(content.text) yield self.create_text_message(content.text)

@ -1,13 +1,14 @@
import hashlib import hashlib
import json import json
from datetime import datetime from datetime import datetime
from typing import Any
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -120,7 +121,7 @@ class MCPToolManageService:
tools = mcp_client.list_tools() tools = mcp_client.list_tools()
except MCPAuthError as e: except MCPAuthError as e:
raise ValueError("Please auth the tool first") raise ValueError("Please auth the tool first")
except MCPConnectionError as e: except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}") raise ValueError(f"Failed to connect to MCP server: {e}")
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True mcp_provider.authed = True
@ -174,7 +175,7 @@ class MCPToolManageService:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash: if server_url_hash != mcp_provider.server_url_hash:
cls._re_auth_mcp_provider(mcp_provider, provider_id, tenant_id) cls._re_connect_mcp_provider(mcp_provider, provider_id, tenant_id)
mcp_provider.server_url_hash = server_url_hash mcp_provider.server_url_hash = server_url_hash
try: try:
db.session.commit() db.session.commit()
@ -191,7 +192,9 @@ class MCPToolManageService:
raise raise
@classmethod @classmethod
def update_mcp_provider_credentials(cls, mcp_provider: MCPToolProvider, credentials: dict, authed: bool = False): def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
provider_controller = MCPToolProviderController._from_db(mcp_provider) provider_controller = MCPToolProviderController._from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id, tenant_id=mcp_provider.tenant_id,
@ -202,11 +205,13 @@ class MCPToolManageService:
mcp_provider.updated_at = datetime.now() mcp_provider.updated_at = datetime.now()
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials}) mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed mcp_provider.authed = authed
if not authed:
mcp_provider.tools = "[]"
db.session.commit() db.session.commit()
@classmethod @classmethod
def _re_auth_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str): def _re_connect_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
"""re-auth mcp provider""" """re-connect mcp provider"""
try: try:
with MCPClient( with MCPClient(
mcp_provider.decrypted_server_url, mcp_provider.decrypted_server_url,
@ -221,6 +226,7 @@ class MCPToolManageService:
except MCPAuthError: except MCPAuthError:
mcp_provider.authed = False mcp_provider.authed = False
mcp_provider.tools = "[]" mcp_provider.tools = "[]"
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
# reset credentials # reset credentials
mcp_provider.encrypted_credentials = "{}" mcp_provider.encrypted_credentials = "{}"

Loading…
Cancel
Save