fix: mypy error

pull/22036/head
Novice 12 months ago
parent a9e73653a8
commit ecd18b70a1

@ -1,6 +1,6 @@
from amqp import NotFound
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from pydantic import ValidationError from pydantic import ValidationError
from werkzeug.exceptions import NotFound
from controllers.mcp import api from controllers.mcp import api
from controllers.web.error import ( from controllers.web.error import (

@ -123,16 +123,14 @@ def create_ssrf_proxy_mcp_http_client(
Returns: Returns:
Configured httpx.Client with proxy settings Configured httpx.Client with proxy settings
""" """
client_kwargs = {
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
"headers": headers or {},
"timeout": timeout,
"follow_redirects": True, # Enable redirect following for MCP connections
}
if dify_config.SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:
client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL return httpx.Client(
return httpx.Client(**client_kwargs) verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = { proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY), "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
@ -140,10 +138,20 @@ def create_ssrf_proxy_mcp_http_client(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
), ),
} }
client_kwargs["mounts"] = proxy_mounts return httpx.Client(
return httpx.Client(**client_kwargs) verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else: else:
return httpx.Client(**client_kwargs) return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

@ -24,8 +24,8 @@ def generate_pkce_challenge() -> tuple[str, str]:
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8") code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_") code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest() code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8") code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_") code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
return code_verifier, code_challenge return code_verifier, code_challenge
@ -213,12 +213,12 @@ def auth(
provider.save_tokens(tokens) provider.save_tokens(tokens)
return {"result": "success"} return {"result": "success"}
tokens = provider.tokens() provider_tokens = provider.tokens()
# Handle token refresh or new authorization # Handle token refresh or new authorization
if tokens and tokens.refresh_token: if provider_tokens and provider_tokens.refresh_token:
try: try:
new_tokens = refresh_authorization(server_url, metadata, client_information, tokens.refresh_token) new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens) provider.save_tokens(new_tokens)
return {"result": "success"} return {"result": "success"}
except Exception as e: except Exception as e:

@ -90,4 +90,4 @@ class OAuthClientProvider:
if not mcp_provider: if not mcp_provider:
return "" return ""
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id) credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
return credentials.get("code_verifier", "") return str(credentials.get("code_verifier", ""))

@ -3,7 +3,7 @@ import queue
from collections.abc import Generator from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import Any, cast
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import httpx import httpx
@ -38,9 +38,9 @@ def sse_client(
if headers is None: if headers is None:
headers = {} headers = {}
read_queue = queue.Queue() read_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
write_queue = queue.Queue() write_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
status_queue = queue.Queue() status_queue: queue.Queue[tuple[str, str | Exception]] = queue.Queue()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
try: try:
@ -97,6 +97,9 @@ def sse_client(
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None: if message is None:
break break
if isinstance(message, Exception):
write_queue.put(message)
continue
response = client.post( response = client.post(
endpoint_url, endpoint_url,
json=message.message.model_dump( json=message.message.model_dump(
@ -119,13 +122,14 @@ def sse_client(
executor.submit(sse_reader, status_queue) executor.submit(sse_reader, status_queue)
try: try:
status, endpoint_url = status_queue.get(timeout=1) status, endpoint_url_or_error = status_queue.get(timeout=1)
except queue.Empty: except queue.Empty:
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")
if status != "ready": if status != "ready":
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")
if status == "error": if status == "error" and isinstance(endpoint_url_or_error, Exception):
raise endpoint_url raise endpoint_url_or_error
endpoint_url = cast(str, endpoint_url_or_error)
executor.submit(post_writer, endpoint_url) executor.submit(post_writer, endpoint_url)
yield read_queue, write_queue yield read_queue, write_queue

@ -431,8 +431,8 @@ def streamablehttp_client(
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
# Create queues with clear directional meaning # Create queues with clear directional meaning
server_to_client_queue = queue.Queue() # For messages FROM server TO client server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue = queue.Queue() # For messages FROM client TO server client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor: with ThreadPoolExecutor(max_workers=2) as executor:
try: try:

@ -1,7 +1,8 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from contextlib import ExitStack from contextlib import AbstractContextManager, ExitStack
from typing import Optional, cast from types import TracebackType
from typing import Any, Optional, cast
from urllib.parse import urlparse from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client from core.mcp.client.sse_client import sse_client
@ -39,8 +40,8 @@ class MCPClient:
# Initialize session and client objects # Initialize session and client objects
self._session: Optional[ClientSession] = None self._session: Optional[ClientSession] = None
self._streams_context = None self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context = None self._session_context: Optional[ClientSession] = None
self.exit_stack = ExitStack() self.exit_stack = ExitStack()
# Whether the client has been initialized # Whether the client has been initialized
@ -51,14 +52,19 @@ class MCPClient:
self._initialized = True self._initialized = True
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
):
self.cleanup() self.cleanup()
def _initialize( def _initialize(
self, self,
): ):
"""Initialize the client with fallback to SSE if streamable connection fails""" """Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client} connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
"mcp": streamablehttp_client,
"sse": sse_client,
}
parsed_url = urlparse(self.server_url) parsed_url = urlparse(self.server_url)
path = parsed_url.path path = parsed_url.path
@ -72,7 +78,9 @@ class MCPClient:
except MCPConnectionError: except MCPConnectionError:
self.connect_server(streamablehttp_client, "mcp") self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True): def connect_server(
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
):
from core.mcp.auth.auth_flow import auth from core.mcp.auth.auth_flow import auth
try: try:
@ -82,6 +90,9 @@ class MCPClient:
else {} else {}
) )
self._streams_context = client_factory(url=self.server_url, headers=headers) self._streams_context = client_factory(url=self.server_url, headers=headers)
if self._streams_context is None:
raise MCPConnectionError("Failed to create connection context")
if method_name == "mcp": if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__() read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream) streams = (read_stream, write_stream)
@ -123,8 +134,8 @@ class MCPClient:
def cleanup(self): def cleanup(self):
"""Clean up resources""" """Clean up resources"""
try: try:
if self._session: if self._session_context:
self._session.__exit__(None, None, None) self._session_context.__exit__(None, None, None)
if self._streams_context: if self._streams_context:
self._streams_context.__exit__(None, None, None) self._streams_context.__exit__(None, None, None)

@ -1,6 +1,6 @@
import json import json
from collections.abc import Mapping from collections.abc import Mapping
from typing import cast from typing import Any, cast
from configs import dify_config from configs import dify_config
from controllers.web.passport import generate_session_id from controllers.web.passport import generate_session_id
@ -153,7 +153,7 @@ class MCPServerReuqestHandler:
) )
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]): def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters = {} parameters: dict[str, dict[str, Any]] = {}
required = [] required = []
for item in user_input_form: for item in user_input_form:
parameters[item.variable] = {} parameters[item.variable] = {}

@ -38,7 +38,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1 DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder(Generic[ReceiveRequestT, SendResultT]): class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@ -57,6 +57,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
3. Cleanup of in-flight requests 3. Cleanup of in-flight requests
""" """
request: ReceiveRequestT
_session: Any
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__( def __init__(
self, self,
request_id: RequestId, request_id: RequestId,
@ -146,6 +150,8 @@ class BaseSession(
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]] _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
_request_id: int _request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_receive_request_type: type[ReceiveRequestT]
_receive_notification_type: type[ReceiveNotificationT]
def __init__( def __init__(
self, self,
@ -165,7 +171,6 @@ class BaseSession(
self._session_read_timeout_seconds = read_timeout_seconds self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {} self._in_flight = {}
self._exit_stack = ExitStack() self._exit_stack = ExitStack()
self._futures = []
def __enter__(self) -> Self: def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor() self._executor = ThreadPoolExecutor()
@ -198,7 +203,7 @@ class BaseSession(
request_id = self._request_id request_id = self._request_id
self._request_id = request_id + 1 self._request_id = request_id + 1
response_queue = queue.Queue() response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
self._response_streams[request_id] = response_queue self._response_streams[request_id] = response_queue
try: try:
@ -211,9 +216,9 @@ class BaseSession(
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
timeout = DEFAULT_RESPONSE_READ_TIMEOUT timeout = DEFAULT_RESPONSE_READ_TIMEOUT
if request_read_timeout_seconds is not None: if request_read_timeout_seconds is not None:
timeout = request_read_timeout_seconds.total_seconds() timeout = float(request_read_timeout_seconds.total_seconds())
elif self._session_read_timeout_seconds is not None: elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds() timeout = float(self._session_read_timeout_seconds.total_seconds())
while True: while True:
try: try:
response_or_error = response_queue.get(timeout=timeout) response_or_error = response_queue.get(timeout=timeout)

@ -340,8 +340,8 @@ class ClientSession(
case types.ListRootsRequest(): case types.ListRootsRequest():
with responder: with responder:
response = self._list_roots_callback(ctx) list_roots_response = self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response) client_response = ClientResponse.validate_python(list_roots_response)
responder.respond(client_response) responder.respond(client_response)
case types.PingRequest(): case types.PingRequest():

@ -689,8 +689,8 @@ class ToolManager:
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters: if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id) mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id)
for provider in mcp_providers: for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{provider.name}"] = provider result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
return BuiltinToolProviderSort.sort(list(result_providers.values())) return BuiltinToolProviderSort.sort(list(result_providers.values()))

@ -1467,7 +1467,7 @@ class AppMCPServer(Base):
@property @property
def parameters_dict(self) -> dict[str, Any]: def parameters_dict(self) -> dict[str, Any]:
return json.loads(self.parameters) return cast(dict[str, Any], json.loads(self.parameters))
class Site(Base): class Site(Base):

@ -248,9 +248,8 @@ class MCPToolProvider(Base):
return [Tool(**tool) for tool in json.loads(self.tools)] return [Tool(**tool) for tool in json.loads(self.tools)]
@property @property
def provider_icon(self) -> str: def provider_icon(self) -> dict[str, str]:
icon_dict = json.loads(self.icon) return cast(dict[str, str], json.loads(self.icon))
return icon_dict
class ToolModelInvoke(Base): class ToolModelInvoke(Base):

Loading…
Cancel
Save