Merge branch 'main' into fix/1.4.3-install-plugins
commit
8bcfa6936e
@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MatrixoneConfig(BaseModel):
|
||||
"""Matrixone vector database configuration."""
|
||||
|
||||
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
||||
MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
|
||||
MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
|
||||
MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
|
||||
MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
|
||||
MATRIXONE_METRIC: str = Field(
|
||||
default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
|
||||
)
|
||||
@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PyProjectConfig(BaseModel):
|
||||
version: str = Field(description="Dify version", default="")
|
||||
|
||||
|
||||
class PyProjectTomlConfig(BaseSettings):
|
||||
"""
|
||||
configs in api/pyproject.toml
|
||||
"""
|
||||
|
||||
project: PyProjectConfig = Field(
|
||||
description="configs in the project section of pyproject.toml",
|
||||
default=PyProjectConfig(),
|
||||
)
|
||||
@ -0,0 +1,102 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=args["description"],
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = args["status"]
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
|
||||
@ -0,0 +1,421 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda _: "env"),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||
}
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "local_file",
|
||||
# "url": "",
|
||||
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||
# }
|
||||
#
|
||||
# Remote File:
|
||||
#
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "remote_url",
|
||||
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
@ -0,0 +1,8 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("mcp", __name__, url_prefix="/mcp")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from . import mcp
|
||||
@ -0,0 +1,104 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import api
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
|
||||
|
||||
class MCPAppApi(Resource):
|
||||
def post(self, server_code):
|
||||
def int_or_str(value):
|
||||
if isinstance(value, (int, str)):
|
||||
return value
|
||||
else:
|
||||
return None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("jsonrpc", type=str, required=True, location="json")
|
||||
parser.add_argument("method", type=str, required=True, location="json")
|
||||
parser.add_argument("params", type=dict, required=False, location="json")
|
||||
parser.add_argument("id", type=int_or_str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
request_id = args.get("id")
|
||||
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
|
||||
app = db.session.query(App).filter(App.id == server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
|
||||
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
|
||||
@ -1 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
|
||||
# comparing `dify_model_identity` with constants throughout the codebase, extract
|
||||
# this logic into a dedicated function. This would encapsulate the implementation
|
||||
# details of how different variable types are identified.
|
||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||
|
||||
|
||||
def maybe_file_object(o: Any) -> bool:
|
||||
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
|
||||
@ -1,67 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from constants import IMAGE_EXTENSIONS
|
||||
from core.helper.url_signer import UrlSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.exception(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file_id: the id of UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
result = UrlSigner.verify(
|
||||
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
|
||||
)
|
||||
|
||||
# verify signature
|
||||
if not result:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -1,22 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, capacity: int):
|
||||
self.cache: OrderedDict[Any, Any] = OrderedDict()
|
||||
self.capacity = capacity
|
||||
|
||||
def get(self, key: Any) -> Any:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
else:
|
||||
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key: Any, value: Any) -> None:
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
self.cache[key] = value
|
||||
if len(self.cache) > self.capacity:
|
||||
self.cache.popitem(last=False) # pop the first item
|
||||
@ -0,0 +1,380 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Optional, cast, overload
|
||||
|
||||
import json_repair
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
"""
|
||||
Invoke large language model with structured output
|
||||
1. This method invokes model_instance.invoke_llm with json_schema
|
||||
2. Try to parse the result as structured output
|
||||
|
||||
:param prompt_messages: prompt messages
|
||||
:param json_schema: json schema
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# handle native json schema
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
}
|
||||
|
||||
if model_schema.support_structure_output:
|
||||
model_parameters = _handle_native_json_schema(
|
||||
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
|
||||
)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
|
||||
|
||||
# handle prompt based schema
|
||||
prompt_messages = _handle_prompt_based_schema(
|
||||
prompt_messages=prompt_messages,
|
||||
structured_output_schema=json_schema,
|
||||
)
|
||||
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters_with_json_schema,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if isinstance(llm_result, LLMResult):
|
||||
if not isinstance(llm_result.message.content, str):
|
||||
raise OutputParserError(
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(llm_result.message.content),
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
system_fingerprint=llm_result.system_fingerprint,
|
||||
prompt_messages=llm_result.prompt_messages,
|
||||
)
|
||||
else:
|
||||
|
||||
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
result_text: str = ""
|
||||
prompt_messages: Sequence[PromptMessage] = []
|
||||
system_fingerprint: Optional[str] = None
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
elif isinstance(event.delta.message.content, list):
|
||||
for item in event.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(result_text),
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
finish_reason=None,
|
||||
),
|
||||
)
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
structured_output_schema: Mapping,
|
||||
model_parameters: dict,
|
||||
rules: list[ParameterRule],
|
||||
) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
|
||||
def _set_response_format(model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
|
||||
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||
structured_output: Mapping[str, Any] = {}
|
||||
parsed: Mapping[str, Any] = {}
|
||||
try:
|
||||
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except ValidationError:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
temp_parsed = json_repair.loads(result_text)
|
||||
if not isinstance(temp_parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(temp_parsed, list):
|
||||
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = cast(dict, temp_parsed)
|
||||
return structured_output
|
||||
|
||||
|
||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
Different models have different requirements for JSON schema formatting.
|
||||
This function handles these differences.
|
||||
|
||||
:param schema: The original JSON schema
|
||||
:return: Processed schema compatible with the current model
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = dict(deepcopy(schema))
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in model_schema.model:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
@ -0,0 +1,342 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
|
||||
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
||||
# Generate a secure random state key
|
||||
state_key = secrets.token_urlsafe(32)
|
||||
|
||||
# Store the state data in Redis with expiration
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
|
||||
|
||||
return state_key
|
||||
|
||||
|
||||
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
|
||||
# Get state data from Redis
|
||||
state_data = redis_client.get(redis_key)
|
||||
|
||||
if not state_data:
|
||||
raise ValueError("State parameter has expired or does not exist")
|
||||
|
||||
# Delete the state data from Redis immediately after retrieval to prevent reuse
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
try:
|
||||
# Parse and validate the state data
|
||||
oauth_state = OAuthCallbackState.model_validate_json(state_data)
|
||||
|
||||
return oauth_state
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
tokens = exchange_authorization(
|
||||
full_state_data.server_url,
|
||||
full_state_data.metadata,
|
||||
full_state_data.client_information,
|
||||
authorization_code,
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
|
||||
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except requests.RequestException as e:
|
||||
if isinstance(e, requests.ConnectionError):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
|
||||
|
||||
def start_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
redirect_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow with secure Redis state storage."""
|
||||
response_type = "code"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
if metadata:
|
||||
authorization_url = metadata.authorization_endpoint
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
if (
|
||||
not metadata.code_challenge_methods_supported
|
||||
or code_challenge_method not in metadata.code_challenge_methods_supported
|
||||
):
|
||||
raise ValueError(
|
||||
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
|
||||
)
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_challenge()
|
||||
|
||||
# Prepare state data with all necessary information
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
server_url=server_url,
|
||||
metadata=metadata,
|
||||
client_information=client_information,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=redirect_url,
|
||||
)
|
||||
|
||||
# Store state data in Redis and generate secure state key
|
||||
state_key = _create_secure_redis_state(state_data)
|
||||
|
||||
params = {
|
||||
"response_type": response_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"redirect_uri": redirect_url,
|
||||
"state": state_key,
|
||||
}
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
|
||||
def exchange_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
authorization_code: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = "authorization_code"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code": authorization_code,
|
||||
"code_verifier": code_verifier,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = "refresh_token"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.parse_obj(response.json())
|
||||
|
||||
|
||||
def register_client(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_metadata: OAuthClientMetadata,
|
||||
) -> OAuthClientInformationFull:
|
||||
"""Performs OAuth 2.0 Dynamic Client Registration."""
|
||||
if metadata:
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
||||
registration_url = metadata.registration_endpoint
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = requests.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
if not response.ok:
|
||||
response.raise_for_status()
|
||||
return OAuthClientInformationFull.model_validate(response.json())
|
||||
|
||||
|
||||
def auth(
|
||||
provider: OAuthClientProvider,
|
||||
server_url: str,
|
||||
authorization_code: Optional[str] = None,
|
||||
state_param: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
|
||||
# Handle client registration if needed
|
||||
client_information = provider.client_information()
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
try:
|
||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
provider.save_client_information(full_information)
|
||||
client_information = full_information
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
|
||||
try:
|
||||
# Retrieve state data from Redis using state key
|
||||
full_state_data = _retrieve_redis_state(state_param)
|
||||
|
||||
code_verifier = full_state_data.code_verifier
|
||||
redirect_uri = full_state_data.redirect_uri
|
||||
|
||||
if not code_verifier or not redirect_uri:
|
||||
raise ValueError("Missing code_verifier or redirect_uri in state data")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid state parameter: {e}")
|
||||
|
||||
tokens = exchange_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
authorization_code,
|
||||
code_verifier,
|
||||
redirect_uri,
|
||||
)
|
||||
provider.save_tokens(tokens)
|
||||
return {"result": "success"}
|
||||
|
||||
provider_tokens = provider.tokens()
|
||||
|
||||
# Handle token refresh or new authorization
|
||||
if provider_tokens and provider_tokens.refresh_token:
|
||||
try:
|
||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
||||
provider.save_tokens(new_tokens)
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
return {"authorization_url": authorization_url}
|
||||
@ -0,0 +1,81 @@
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""The URL to redirect the user agent to after authorization."""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
)
|
||||
|
||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
||||
"""Loads information about this OAuth client."""
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> Optional[OAuthTokens]:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def save_tokens(self, tokens: OAuthTokens) -> None:
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str) -> None:
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
@ -0,0 +1,361 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, TypeAlias, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import SessionMessage
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
@final
|
||||
class _StatusReady:
|
||||
def __init__(self, endpoint_url: str):
|
||||
self._endpoint_url = endpoint_url
|
||||
|
||||
|
||||
@final
|
||||
class _StatusError:
|
||||
def __init__(self, exc: Exception):
|
||||
self._exc = exc
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
"""Remove request parameters from URL, keeping only the path."""
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
class SSETransport:
|
||||
"""SSE client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> None:
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.endpoint_url: str | None = None
|
||||
|
||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||
"""Validate that the endpoint URL matches the connection origin.
|
||||
|
||||
Args:
|
||||
endpoint_url: The endpoint URL to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
"""
|
||||
url_parsed = urlparse(self.url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
|
||||
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
||||
|
||||
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
|
||||
"""Handle an 'endpoint' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
endpoint_url = urljoin(self.url, sse_data)
|
||||
logger.info(f"Received endpoint URL: {endpoint_url}")
|
||||
|
||||
if not self._validate_endpoint_url(endpoint_url):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
status_queue.put(_StatusError(ValueError(error_msg)))
|
||||
return
|
||||
|
||||
status_queue.put(_StatusReady(endpoint_url))
|
||||
|
||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
||||
"""Handle a 'message' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
read_queue: Queue to put parsed messages.
|
||||
"""
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse_data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
session_message = SessionMessage(message)
|
||||
read_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
|
||||
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Handle a single SSE event.
|
||||
|
||||
Args:
|
||||
sse: The SSE event object.
|
||||
read_queue: Queue for message events.
|
||||
status_queue: Queue for status events.
|
||||
"""
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
self._handle_endpoint_event(sse.data, status_queue)
|
||||
case "message":
|
||||
self._handle_message_event(sse.data, read_queue)
|
||||
case _:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
|
||||
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Read and process SSE events.
|
||||
|
||||
Args:
|
||||
event_source: The SSE event source.
|
||||
read_queue: Queue to put received messages.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, read_queue, status_queue)
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
|
||||
"""Send a single message to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send to.
|
||||
message: The message to send.
|
||||
"""
|
||||
response = client.post(
|
||||
endpoint_url,
|
||||
json=message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
|
||||
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
|
||||
"""Handle writing messages to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send messages to.
|
||||
write_queue: Queue to read messages from.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, Exception):
|
||||
write_queue.put(message)
|
||||
continue
|
||||
|
||||
self._send_message(client, endpoint_url, message)
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"Post writer shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
|
||||
"""Wait for the endpoint URL from the status queue.
|
||||
|
||||
Args:
|
||||
status_queue: Queue to read status from.
|
||||
|
||||
Returns:
|
||||
The endpoint URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint URL is not received or there's an error.
|
||||
"""
|
||||
try:
|
||||
status = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if isinstance(status, _StatusReady):
|
||||
return status._endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status._exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
def connect(
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
client: httpx.Client,
|
||||
event_source,
|
||||
) -> tuple[ReadQueue, WriteQueue]:
|
||||
"""Establish connection and start worker threads.
|
||||
|
||||
Args:
|
||||
executor: Thread pool executor.
|
||||
client: HTTP client.
|
||||
event_source: SSE event source.
|
||||
|
||||
Returns:
|
||||
Tuple of (read_queue, write_queue).
|
||||
"""
|
||||
read_queue: ReadQueue = queue.Queue()
|
||||
write_queue: WriteQueue = queue.Queue()
|
||||
status_queue: StatusQueue = queue.Queue()
|
||||
|
||||
# Start SSE reader thread
|
||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||
|
||||
# Wait for endpoint URL
|
||||
endpoint_url = self._wait_for_endpoint(status_queue)
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
# Start post writer thread
|
||||
executor.submit(self.post_writer, client, endpoint_url, write_queue)
|
||||
|
||||
return read_queue, write_queue
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_queue, write_queue) for message communication.
|
||||
"""
|
||||
transport = SSETransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
|
||||
"""
|
||||
Send a message to the server using the provided HTTP client.
|
||||
|
||||
Args:
|
||||
http_client: The HTTP client to use for sending
|
||||
endpoint_url: The endpoint URL to send the message to
|
||||
session_message: The message to send
|
||||
"""
|
||||
try:
|
||||
response = http_client.post(
|
||||
endpoint_url,
|
||||
json=session_message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error sending message")
|
||||
raise
|
||||
|
||||
|
||||
def read_messages(
|
||||
sse_client: SSEClient,
|
||||
) -> Generator[SessionMessage | Exception, None, None]:
|
||||
"""
|
||||
Read messages from the SSE client.
|
||||
|
||||
Args:
|
||||
sse_client: The SSE client to read from
|
||||
|
||||
Yields:
|
||||
SessionMessage or Exception for each event received
|
||||
"""
|
||||
try:
|
||||
for sse in sse_client.events():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
yield SessionMessage(message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
yield exc
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error reading SSE messages")
|
||||
yield exc
|
||||
@ -0,0 +1,476 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
|
||||
from core.mcp.types import (
|
||||
ClientMessageMetadata,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SessionMessageOrError = SessionMessage | Exception | None
|
||||
# Queue types with clearer names for their roles
|
||||
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
|
||||
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
|
||||
GetSessionIdCallback = Callable[[], str | None]
|
||||
|
||||
MCP_SESSION_ID = "mcp-session-id"
|
||||
LAST_EVENT_ID = "last-event-id"
|
||||
CONTENT_TYPE = "content-type"
|
||||
ACCEPT = "Accept"
|
||||
|
||||
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""Context for a request operation."""
|
||||
|
||||
client: httpx.Client
|
||||
headers: dict[str, str]
|
||||
session_id: str | None
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||
sse_read_timeout: timedelta
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
Args:
|
||||
url: The endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
|
||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
headers = base_headers.copy()
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
) -> None:
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info(f"Received session ID: {self.session_id}")
|
||||
|
||||
def _handle_sse_event(
|
||||
self,
|
||||
sse: ServerSentEvent,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
original_request_id: RequestId | None = None,
|
||||
resumption_callback: Callable[[str], None] | None = None,
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"SSE message: {message}")
|
||||
|
||||
# If this is a response and we have original_request_id, replace it
|
||||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
# Put message in queue that goes to client
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
if sse.id and resumption_callback:
|
||||
resumption_callback(sse.id)
|
||||
|
||||
# If this is a response or error return True indicating completion
|
||||
# Otherwise, return False to continue listening
|
||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||
|
||||
except Exception as exc:
|
||||
# Put exception in queue that goes to client
|
||||
server_to_client_queue.put(exc)
|
||||
return False
|
||||
elif sse.event == "ping":
|
||||
logger.debug("Received ping event")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
return False
|
||||
|
||||
def handle_get_stream(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle GET stream for server-initiated messages."""
|
||||
try:
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug(f"GET stream error (non-fatal): {exc}")
|
||||
|
||||
def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
if ctx.metadata and ctx.metadata.resumption_token:
|
||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||
else:
|
||||
raise ResumptionError("Resumption request requires a resumption token")
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
|
||||
def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
|
||||
with ctx.client.stream(
|
||||
"POST",
|
||||
self.url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
self._send_session_terminated_error(
|
||||
ctx.server_to_client_queue,
|
||||
message.root.id,
|
||||
)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
|
||||
def _handle_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle JSON response from the server."""
|
||||
try:
|
||||
content = response.read()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
session_message = SessionMessage(message)
|
||||
server_to_client_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
|
||||
"""Handle SSE response from the server."""
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
def _handle_unexpected_content_type(
|
||||
self,
|
||||
content_type: str,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle unexpected content type in response."""
|
||||
error_msg = f"Unexpected content type: {content_type}"
|
||||
logger.error(error_msg)
|
||||
server_to_client_queue.put(ValueError(error_msg))
|
||||
|
||||
def _send_session_terminated_error(
|
||||
self,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
request_id: RequestId,
|
||||
) -> None:
|
||||
"""Send a session terminated error response."""
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated by server"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
def post_writer(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
client_to_server_queue: ClientToServerQueue,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
start_get_stream: Callable[[], None],
|
||||
) -> None:
|
||||
"""Handle writing requests to the server.
|
||||
|
||||
This method processes messages from the client_to_server_queue and sends them to the server.
|
||||
Responses are written to the server_to_client_queue.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if session_message is None:
|
||||
break
|
||||
|
||||
message = session_message.message
|
||||
metadata = (
|
||||
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
|
||||
)
|
||||
|
||||
# Check if this is a resumption request
|
||||
is_resumption = bool(metadata and metadata.resumption_token)
|
||||
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
|
||||
# Handle initialized notification
|
||||
if self._is_initialized_notification(message):
|
||||
start_get_stream()
|
||||
|
||||
ctx = RequestContext(
|
||||
client=client,
|
||||
headers=self.request_headers,
|
||||
session_id=self.session_id,
|
||||
session_message=session_message,
|
||||
metadata=metadata,
|
||||
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
self._handle_resumption_request(ctx)
|
||||
else:
|
||||
self._handle_post_request(ctx)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def terminate_session(self, client: httpx.Client) -> None:
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
response = client.delete(self.url, headers=headers)
|
||||
|
||||
if response.status_code == 405:
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code != 200:
|
||||
logger.warning(f"Session termination failed: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
terminate_on_close: bool = True,
|
||||
) -> Generator[
|
||||
tuple[
|
||||
ServerToClientQueue, # Queue for receiving messages FROM server
|
||||
ClientToServerQueue, # Queue for sending messages TO server
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Client transport for StreamableHTTP.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
- server_to_client_queue: Queue for reading messages FROM the server
|
||||
- client_to_server_queue: Queue for sending messages TO the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
"""
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
# Create queues with clear directional meaning
|
||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream() -> None:
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
@ -0,0 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue