fix: mypy error

pull/22036/head
Novice 12 months ago
parent a9e73653a8
commit ecd18b70a1

@ -1,6 +1,6 @@
from amqp import NotFound
from flask_restful import Resource, reqparse
from pydantic import ValidationError
from werkzeug.exceptions import NotFound
from controllers.mcp import api
from controllers.web.error import (

@ -123,16 +123,14 @@ def create_ssrf_proxy_mcp_http_client(
Returns:
Configured httpx.Client with proxy settings
"""
client_kwargs = {
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
"headers": headers or {},
"timeout": timeout,
"follow_redirects": True, # Enable redirect following for MCP connections
}
if dify_config.SSRF_PROXY_ALL_URL:
client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL
return httpx.Client(**client_kwargs)
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
@ -140,10 +138,20 @@ def create_ssrf_proxy_mcp_http_client(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
client_kwargs["mounts"] = proxy_mounts
return httpx.Client(**client_kwargs)
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(**client_kwargs)
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

@ -24,8 +24,8 @@ def generate_pkce_challenge() -> tuple[str, str]:
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8")
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
return code_verifier, code_challenge
@ -213,12 +213,12 @@ def auth(
provider.save_tokens(tokens)
return {"result": "success"}
tokens = provider.tokens()
provider_tokens = provider.tokens()
# Handle token refresh or new authorization
if tokens and tokens.refresh_token:
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, tokens.refresh_token)
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
return {"result": "success"}
except Exception as e:

@ -90,4 +90,4 @@ class OAuthClientProvider:
if not mcp_provider:
return ""
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
return credentials.get("code_verifier", "")
return str(credentials.get("code_verifier", ""))

@ -3,7 +3,7 @@ import queue
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any
from typing import Any, cast
from urllib.parse import urljoin, urlparse
import httpx
@ -38,9 +38,9 @@ def sse_client(
if headers is None:
headers = {}
read_queue = queue.Queue()
write_queue = queue.Queue()
status_queue = queue.Queue()
read_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
write_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
status_queue: queue.Queue[tuple[str, str | Exception]] = queue.Queue()
with ThreadPoolExecutor() as executor:
try:
@ -97,6 +97,9 @@ def sse_client(
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, Exception):
write_queue.put(message)
continue
response = client.post(
endpoint_url,
json=message.message.model_dump(
@ -119,13 +122,14 @@ def sse_client(
executor.submit(sse_reader, status_queue)
try:
status, endpoint_url = status_queue.get(timeout=1)
status, endpoint_url_or_error = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error":
raise endpoint_url
if status == "error" and isinstance(endpoint_url_or_error, Exception):
raise endpoint_url_or_error
endpoint_url = cast(str, endpoint_url_or_error)
executor.submit(post_writer, endpoint_url)
yield read_queue, write_queue

@ -431,8 +431,8 @@ def streamablehttp_client(
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
# Create queues with clear directional meaning
server_to_client_queue = queue.Queue() # For messages FROM server TO client
client_to_server_queue = queue.Queue() # For messages FROM client TO server
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:

@ -1,7 +1,8 @@
import logging
from collections.abc import Callable
from contextlib import ExitStack
from typing import Optional, cast
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any, Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
@ -39,8 +40,8 @@ class MCPClient:
# Initialize session and client objects
self._session: Optional[ClientSession] = None
self._streams_context = None
self._session_context = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self.exit_stack = ExitStack()
# Whether the client has been initialized
@ -51,14 +52,19 @@ class MCPClient:
self._initialized = True
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
):
self.cleanup()
def _initialize(
self,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
"mcp": streamablehttp_client,
"sse": sse_client,
}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
@ -72,7 +78,9 @@ class MCPClient:
except MCPConnectionError:
self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
def connect_server(
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
):
from core.mcp.auth.auth_flow import auth
try:
@ -82,6 +90,9 @@ class MCPClient:
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if self._streams_context is None:
raise MCPConnectionError("Failed to create connection context")
if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream)
@ -123,8 +134,8 @@ class MCPClient:
def cleanup(self):
"""Clean up resources"""
try:
if self._session:
self._session.__exit__(None, None, None)
if self._session_context:
self._session_context.__exit__(None, None, None)
if self._streams_context:
self._streams_context.__exit__(None, None, None)

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping
from typing import cast
from typing import Any, cast
from configs import dify_config
from controllers.web.passport import generate_session_id
@ -153,7 +153,7 @@ class MCPServerReuqestHandler:
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters = {}
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
parameters[item.variable] = {}

@ -38,7 +38,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@ -57,6 +57,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
3. Cleanup of in-flight requests
"""
request: ReceiveRequestT
_session: Any
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__(
self,
request_id: RequestId,
@ -146,6 +150,8 @@ class BaseSession(
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_receive_request_type: type[ReceiveRequestT]
_receive_notification_type: type[ReceiveNotificationT]
def __init__(
self,
@ -165,7 +171,6 @@ class BaseSession(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
self._futures = []
def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor()
@ -198,7 +203,7 @@ class BaseSession(
request_id = self._request_id
self._request_id = request_id + 1
response_queue = queue.Queue()
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
self._response_streams[request_id] = response_queue
try:
@ -211,9 +216,9 @@ class BaseSession(
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
if request_read_timeout_seconds is not None:
timeout = request_read_timeout_seconds.total_seconds()
timeout = float(request_read_timeout_seconds.total_seconds())
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
timeout = float(self._session_read_timeout_seconds.total_seconds())
while True:
try:
response_or_error = response_queue.get(timeout=timeout)

@ -340,8 +340,8 @@ class ClientSession(
case types.ListRootsRequest():
with responder:
response = self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
list_roots_response = self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(list_roots_response)
responder.respond(client_response)
case types.PingRequest():

@ -689,8 +689,8 @@ class ToolManager:
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id)
for provider in mcp_providers:
result_providers[f"mcp_provider.{provider.name}"] = provider
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))

@ -1467,7 +1467,7 @@ class AppMCPServer(Base):
@property
def parameters_dict(self) -> dict[str, Any]:
return json.loads(self.parameters)
return cast(dict[str, Any], json.loads(self.parameters))
class Site(Base):

@ -248,9 +248,8 @@ class MCPToolProvider(Base):
return [Tool(**tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> str:
icon_dict = json.loads(self.icon)
return icon_dict
def provider_icon(self) -> dict[str, str]:
return cast(dict[str, str], json.loads(self.icon))
class ToolModelInvoke(Base):

Loading…
Cancel
Save