Merge upstream/main - resolved conflicts in answer_node.py
- Combined output variables functionality with upstream ArrayFileSegment changes - Maintained both features: custom outputs and proper file handlingpull/20921/head
commit
3b70c408f9
@ -0,0 +1,28 @@
|
||||
name: Deploy RAG Dev
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
|
||||
@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MatrixoneConfig(BaseModel):
|
||||
"""Matrixone vector database configuration."""
|
||||
|
||||
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
||||
MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
|
||||
MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
|
||||
MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
|
||||
MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
|
||||
MATRIXONE_METRIC: str = Field(
|
||||
default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
|
||||
)
|
||||
@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PyProjectConfig(BaseModel):
|
||||
version: str = Field(description="Dify version", default="")
|
||||
|
||||
|
||||
class PyProjectTomlConfig(BaseSettings):
|
||||
"""
|
||||
configs in api/pyproject.toml
|
||||
"""
|
||||
|
||||
project: PyProjectConfig = Field(
|
||||
description="configs in the project section of pyproject.toml",
|
||||
default=PyProjectConfig(),
|
||||
)
|
||||
@ -0,0 +1,107 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=args["description"],
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = args["status"]
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.filter(AppMCPServer.id == server_id)
|
||||
.filter(AppMCPServer.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
|
||||
@ -0,0 +1,421 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda _: "env"),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||
}
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "local_file",
|
||||
# "url": "",
|
||||
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||
# }
|
||||
#
|
||||
# Remote File:
|
||||
#
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "remote_url",
|
||||
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
@ -0,0 +1,8 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("mcp", __name__, url_prefix="/mcp")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from . import mcp
|
||||
@ -0,0 +1,104 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import api
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
|
||||
|
||||
class MCPAppApi(Resource):
|
||||
def post(self, server_code):
|
||||
def int_or_str(value):
|
||||
if isinstance(value, (int, str)):
|
||||
return value
|
||||
else:
|
||||
return None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("jsonrpc", type=str, required=True, location="json")
|
||||
parser.add_argument("method", type=str, required=True, location="json")
|
||||
parser.add_argument("params", type=dict, required=False, location="json")
|
||||
parser.add_argument("id", type=int_or_str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
request_id = args.get("id")
|
||||
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
|
||||
app = db.session.query(App).filter(App.id == server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
|
||||
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
|
||||
@ -1 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
|
||||
# comparing `dify_model_identity` with constants throughout the codebase, extract
|
||||
# this logic into a dedicated function. This would encapsulate the implementation
|
||||
# details of how different variable types are identified.
|
||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||
|
||||
|
||||
def maybe_file_object(o: Any) -> bool:
|
||||
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
|
||||
@ -1,67 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from constants import IMAGE_EXTENSIONS
|
||||
from core.helper.url_signer import UrlSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.exception(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file_id: the id of UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
result = UrlSigner.verify(
|
||||
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
|
||||
)
|
||||
|
||||
# verify signature
|
||||
if not result:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -1,22 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, capacity: int):
|
||||
self.cache: OrderedDict[Any, Any] = OrderedDict()
|
||||
self.capacity = capacity
|
||||
|
||||
def get(self, key: Any) -> Any:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
else:
|
||||
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key: Any, value: Any) -> None:
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
self.cache[key] = value
|
||||
if len(self.cache) > self.capacity:
|
||||
self.cache.popitem(last=False) # pop the first item
|
||||
@ -0,0 +1,380 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Optional, cast, overload
|
||||
|
||||
import json_repair
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
"""
|
||||
Invoke large language model with structured output
|
||||
1. This method invokes model_instance.invoke_llm with json_schema
|
||||
2. Try to parse the result as structured output
|
||||
|
||||
:param prompt_messages: prompt messages
|
||||
:param json_schema: json schema
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# handle native json schema
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
}
|
||||
|
||||
if model_schema.support_structure_output:
|
||||
model_parameters = _handle_native_json_schema(
|
||||
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
|
||||
)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
|
||||
|
||||
# handle prompt based schema
|
||||
prompt_messages = _handle_prompt_based_schema(
|
||||
prompt_messages=prompt_messages,
|
||||
structured_output_schema=json_schema,
|
||||
)
|
||||
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters_with_json_schema,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if isinstance(llm_result, LLMResult):
|
||||
if not isinstance(llm_result.message.content, str):
|
||||
raise OutputParserError(
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(llm_result.message.content),
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
system_fingerprint=llm_result.system_fingerprint,
|
||||
prompt_messages=llm_result.prompt_messages,
|
||||
)
|
||||
else:
|
||||
|
||||
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
result_text: str = ""
|
||||
prompt_messages: Sequence[PromptMessage] = []
|
||||
system_fingerprint: Optional[str] = None
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
elif isinstance(event.delta.message.content, list):
|
||||
for item in event.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(result_text),
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
finish_reason=None,
|
||||
),
|
||||
)
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
structured_output_schema: Mapping,
|
||||
model_parameters: dict,
|
||||
rules: list[ParameterRule],
|
||||
) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
|
||||
def _set_response_format(model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
|
||||
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||
structured_output: Mapping[str, Any] = {}
|
||||
parsed: Mapping[str, Any] = {}
|
||||
try:
|
||||
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except ValidationError:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
temp_parsed = json_repair.loads(result_text)
|
||||
if not isinstance(temp_parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(temp_parsed, list):
|
||||
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = cast(dict, temp_parsed)
|
||||
return structured_output
|
||||
|
||||
|
||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
Different models have different requirements for JSON schema formatting.
|
||||
This function handles these differences.
|
||||
|
||||
:param schema: The original JSON schema
|
||||
:return: Processed schema compatible with the current model
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = dict(deepcopy(schema))
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in model_schema.model:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue