Merge branch 'fix/explore-tabs-change-failed' into fix/e-300
commit
41f4eb044d
@ -0,0 +1,421 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
|
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import (
|
||||||
|
DraftWorkflowNotExist,
|
||||||
|
)
|
||||||
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||||
|
from core.variables.segment_group import SegmentGroup
|
||||||
|
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||||
|
from core.variables.types import SegmentType
|
||||||
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
|
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||||
|
from factories.variable_factory import build_segment_with_type
|
||||||
|
from libs.login import current_user, login_required
|
||||||
|
from models import App, AppMode, db
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||||
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||||
|
if isinstance(value, FileSegment):
|
||||||
|
return value.value.model_dump()
|
||||||
|
elif isinstance(value, ArrayFileSegment):
|
||||||
|
return [i.model_dump() for i in value.value]
|
||||||
|
elif isinstance(value, SegmentGroup):
|
||||||
|
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||||
|
else:
|
||||||
|
return value.value
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||||
|
value = variable.get_value()
|
||||||
|
# create a copy of the value to avoid affecting the model cache.
|
||||||
|
value = value.model_copy(deep=True)
|
||||||
|
# Refresh the url signature before returning it to client.
|
||||||
|
if isinstance(value, FileSegment):
|
||||||
|
file = value.value
|
||||||
|
file.remote_url = file.generate_url()
|
||||||
|
elif isinstance(value, ArrayFileSegment):
|
||||||
|
files = value.value
|
||||||
|
for file in files:
|
||||||
|
file.remote_url = file.generate_url()
|
||||||
|
return _convert_values_to_json_serializable_object(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_pagination_parser():
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"page",
|
||||||
|
type=inputs.int_range(1, 100_000),
|
||||||
|
required=False,
|
||||||
|
default=1,
|
||||||
|
location="args",
|
||||||
|
help="the page of data requested",
|
||||||
|
)
|
||||||
|
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||||
|
"id": fields.String,
|
||||||
|
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
|
"value_type": fields.String,
|
||||||
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
|
"visible": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||||
|
value=fields.Raw(attribute=_serialize_var_value),
|
||||||
|
)
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||||
|
"id": fields.String,
|
||||||
|
"type": fields.String(attribute=lambda _: "env"),
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
|
"value_type": fields.String,
|
||||||
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
|
"visible": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||||
|
return var_list.variables
|
||||||
|
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||||
|
"total": fields.Raw(),
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _api_prerequisite(f):
|
||||||
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
|
It ensures the following conditions are satisfied:
|
||||||
|
|
||||||
|
- Dify has been property setup.
|
||||||
|
- The request user has logged in and initialized.
|
||||||
|
- The requested app is a workflow or a chat flow.
|
||||||
|
- The request user has the edit permission for the app.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get draft workflow
|
||||||
|
"""
|
||||||
|
parser = _create_pagination_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||||
|
if not workflow_exist:
|
||||||
|
raise DraftWorkflowNotExist()
|
||||||
|
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||||
|
app_id=app_model.id,
|
||||||
|
page=args.page,
|
||||||
|
limit=args.limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
return workflow_vars
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||||
|
if node_id in [
|
||||||
|
CONVERSATION_VARIABLE_NODE_ID,
|
||||||
|
SYSTEM_VARIABLE_NODE_ID,
|
||||||
|
]:
|
||||||
|
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||||
|
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||||
|
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||||
|
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||||
|
#
|
||||||
|
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||||
|
|
||||||
|
raise InvalidArgumentError(
|
||||||
|
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NodeVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App, node_id: str):
|
||||||
|
validate_node_id(node_id)
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||||
|
|
||||||
|
return node_vars
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App, node_id: str):
|
||||||
|
validate_node_id(node_id)
|
||||||
|
srv = WorkflowDraftVariableService(db.session())
|
||||||
|
srv.delete_node_variables(app_model.id, node_id)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableApi(Resource):
|
||||||
|
_PATCH_NAME_FIELD = "name"
|
||||||
|
_PATCH_VALUE_FIELD = "value"
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
def get(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
return variable
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
def patch(self, app_model: App, variable_id: str):
|
||||||
|
# Request payload for file types:
|
||||||
|
#
|
||||||
|
# Local File:
|
||||||
|
#
|
||||||
|
# {
|
||||||
|
# "type": "image",
|
||||||
|
# "transfer_method": "local_file",
|
||||||
|
# "url": "",
|
||||||
|
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# Remote File:
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# {
|
||||||
|
# "type": "image",
|
||||||
|
# "transfer_method": "remote_url",
|
||||||
|
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||||
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
|
# }
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||||
|
# Parse 'value' field as-is to maintain its original data structure
|
||||||
|
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||||
|
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
args = parser.parse_args(strict=True)
|
||||||
|
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
|
||||||
|
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||||
|
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||||
|
if new_name is None and raw_value is None:
|
||||||
|
return variable
|
||||||
|
|
||||||
|
new_value = None
|
||||||
|
if raw_value is not None:
|
||||||
|
if variable.value_type == SegmentType.FILE:
|
||||||
|
if not isinstance(raw_value, dict):
|
||||||
|
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||||
|
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
|
||||||
|
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||||
|
if not isinstance(raw_value, list):
|
||||||
|
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||||
|
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||||
|
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||||
|
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||||
|
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||||
|
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||||
|
db.session.commit()
|
||||||
|
return variable
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
draft_var_srv.delete_variable(variable)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableResetApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
def put(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_srv = WorkflowService()
|
||||||
|
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||||
|
if draft_workflow is None:
|
||||||
|
raise NotFoundError(
|
||||||
|
f"Draft workflow not found, app_id={app_model.id}",
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
|
||||||
|
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||||
|
db.session.commit()
|
||||||
|
if resetted is None:
|
||||||
|
return Response("", 204)
|
||||||
|
else:
|
||||||
|
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||||
|
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||||
|
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||||
|
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||||
|
else:
|
||||||
|
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||||
|
return draft_vars
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||||
|
# so their IDs can be returned to the caller.
|
||||||
|
workflow_srv = WorkflowService()
|
||||||
|
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||||
|
if draft_workflow is None:
|
||||||
|
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||||
|
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||||
|
db.session.commit()
|
||||||
|
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class SystemVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get draft workflow
|
||||||
|
"""
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||||
|
if workflow is None:
|
||||||
|
raise DraftWorkflowNotExist()
|
||||||
|
|
||||||
|
env_vars = workflow.environment_variables
|
||||||
|
env_vars_list = []
|
||||||
|
for v in env_vars:
|
||||||
|
env_vars_list.append(
|
||||||
|
{
|
||||||
|
"id": v.id,
|
||||||
|
"type": "env",
|
||||||
|
"name": v.name,
|
||||||
|
"description": v.description,
|
||||||
|
"selector": v.selector,
|
||||||
|
"value_type": v.value_type.value,
|
||||||
|
"value": v.value,
|
||||||
|
# Do not track edited for env vars.
|
||||||
|
"edited": False,
|
||||||
|
"visible": True,
|
||||||
|
"editable": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"items": env_vars_list}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
WorkflowVariableCollectionApi,
|
||||||
|
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||||
|
)
|
||||||
|
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||||
|
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||||
|
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||||
|
|
||||||
|
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||||
|
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||||
|
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||||
@ -1 +1,11 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
|
||||||
|
# comparing `dify_model_identity` with constants throughout the codebase, extract
|
||||||
|
# this logic into a dedicated function. This would encapsulate the implementation
|
||||||
|
# details of how different variable types are identified.
|
||||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_file_object(o: Any) -> bool:
|
||||||
|
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||||
|
|||||||
@ -0,0 +1,374 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
|
from copy import deepcopy
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, Literal, Optional, cast, overload
|
||||||
|
|
||||||
|
import json_repair
|
||||||
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
|
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
|
from core.model_runtime.entities.llm_entities import (
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
LLMResultChunkWithStructuredOutput,
|
||||||
|
LLMResultWithStructuredOutput,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormat(StrEnum):
|
||||||
|
"""Constants for model response formats"""
|
||||||
|
|
||||||
|
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||||
|
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||||
|
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialModelType(StrEnum):
|
||||||
|
"""Constants for identifying model types"""
|
||||||
|
|
||||||
|
GEMINI = "gemini"
|
||||||
|
OLLAMA = "ollama"
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def invoke_llm_with_structured_output(
|
||||||
|
provider: str,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
prompt_messages: Sequence[PromptMessage],
|
||||||
|
json_schema: Mapping[str, Any],
|
||||||
|
model_parameters: Optional[Mapping] = None,
|
||||||
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: Literal[True] = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def invoke_llm_with_structured_output(
|
||||||
|
provider: str,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
prompt_messages: Sequence[PromptMessage],
|
||||||
|
json_schema: Mapping[str, Any],
|
||||||
|
model_parameters: Optional[Mapping] = None,
|
||||||
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: Literal[False] = False,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
) -> LLMResultWithStructuredOutput: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def invoke_llm_with_structured_output(
|
||||||
|
provider: str,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
prompt_messages: Sequence[PromptMessage],
|
||||||
|
json_schema: Mapping[str, Any],
|
||||||
|
model_parameters: Optional[Mapping] = None,
|
||||||
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_llm_with_structured_output(
|
||||||
|
provider: str,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
prompt_messages: Sequence[PromptMessage],
|
||||||
|
json_schema: Mapping[str, Any],
|
||||||
|
model_parameters: Optional[Mapping] = None,
|
||||||
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke large language model with structured output
|
||||||
|
1. This method invokes model_instance.invoke_llm with json_schema
|
||||||
|
2. Try to parse the result as structured output
|
||||||
|
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:param json_schema: json schema
|
||||||
|
:param model_parameters: model parameters
|
||||||
|
:param tools: tools for tool calling
|
||||||
|
:param stop: stop words
|
||||||
|
:param stream: is stream response
|
||||||
|
:param user: unique user id
|
||||||
|
:param callbacks: callbacks
|
||||||
|
:return: full response or stream response chunk generator result
|
||||||
|
"""
|
||||||
|
|
||||||
|
# handle native json schema
|
||||||
|
model_parameters_with_json_schema: dict[str, Any] = {
|
||||||
|
**(model_parameters or {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_schema.support_structure_output:
|
||||||
|
model_parameters = _handle_native_json_schema(
|
||||||
|
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Set appropriate response format based on model capabilities
|
||||||
|
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
|
||||||
|
|
||||||
|
# handle prompt based schema
|
||||||
|
prompt_messages = _handle_prompt_based_schema(
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
structured_output_schema=json_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_result = model_instance.invoke_llm(
|
||||||
|
prompt_messages=list(prompt_messages),
|
||||||
|
model_parameters=model_parameters_with_json_schema,
|
||||||
|
tools=tools,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
user=user,
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(llm_result, LLMResult):
|
||||||
|
if not isinstance(llm_result.message.content, str):
|
||||||
|
raise OutputParserError(
|
||||||
|
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return LLMResultWithStructuredOutput(
|
||||||
|
structured_output=_parse_structured_output(llm_result.message.content),
|
||||||
|
model=llm_result.model,
|
||||||
|
message=llm_result.message,
|
||||||
|
usage=llm_result.usage,
|
||||||
|
system_fingerprint=llm_result.system_fingerprint,
|
||||||
|
prompt_messages=llm_result.prompt_messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||||
|
result_text: str = ""
|
||||||
|
prompt_messages: Sequence[PromptMessage] = []
|
||||||
|
system_fingerprint: Optional[str] = None
|
||||||
|
for event in llm_result:
|
||||||
|
if isinstance(event, LLMResultChunk):
|
||||||
|
if isinstance(event.delta.message.content, str):
|
||||||
|
result_text += event.delta.message.content
|
||||||
|
prompt_messages = event.prompt_messages
|
||||||
|
system_fingerprint = event.system_fingerprint
|
||||||
|
|
||||||
|
yield LLMResultChunkWithStructuredOutput(
|
||||||
|
model=model_schema.model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
system_fingerprint=system_fingerprint,
|
||||||
|
delta=event.delta,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield LLMResultChunkWithStructuredOutput(
|
||||||
|
structured_output=_parse_structured_output(result_text),
|
||||||
|
model=model_schema.model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
system_fingerprint=system_fingerprint,
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content=""),
|
||||||
|
usage=None,
|
||||||
|
finish_reason=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return generator()
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_native_json_schema(
|
||||||
|
provider: str,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
structured_output_schema: Mapping,
|
||||||
|
model_parameters: dict,
|
||||||
|
rules: list[ParameterRule],
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Handle structured output for models with native JSON schema support.
|
||||||
|
|
||||||
|
:param model_parameters: Model parameters to update
|
||||||
|
:param rules: Model parameter rules
|
||||||
|
:return: Updated model parameters with JSON schema configuration
|
||||||
|
"""
|
||||||
|
# Process schema according to model requirements
|
||||||
|
schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
|
||||||
|
|
||||||
|
# Set JSON schema in parameters
|
||||||
|
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Set appropriate response format if required by the model
|
||||||
|
for rule in rules:
|
||||||
|
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||||
|
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||||
|
|
||||||
|
return model_parameters
|
||||||
|
|
||||||
|
|
||||||
|
def _set_response_format(model_parameters: dict, rules: list) -> None:
|
||||||
|
"""
|
||||||
|
Set the appropriate response format parameter based on model rules.
|
||||||
|
|
||||||
|
:param model_parameters: Model parameters to update
|
||||||
|
:param rules: Model parameter rules
|
||||||
|
"""
|
||||||
|
for rule in rules:
|
||||||
|
if rule.name == "response_format":
|
||||||
|
if ResponseFormat.JSON.value in rule.options:
|
||||||
|
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||||
|
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||||
|
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_prompt_based_schema(
|
||||||
|
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||||
|
) -> list[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Handle structured output for models without native JSON schema support.
|
||||||
|
This function modifies the prompt messages to include schema-based output requirements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_messages: Original sequence of prompt messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||||
|
"""
|
||||||
|
# Convert schema to string format
|
||||||
|
schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Find existing system prompt with schema placeholder
|
||||||
|
system_prompt = next(
|
||||||
|
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||||
|
# Prepare system prompt content
|
||||||
|
system_prompt_content = (
|
||||||
|
structured_output_prompt + "\n\n" + system_prompt.content
|
||||||
|
if system_prompt and isinstance(system_prompt.content, str)
|
||||||
|
else structured_output_prompt
|
||||||
|
)
|
||||||
|
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||||
|
|
||||||
|
# Extract content from the last user message
|
||||||
|
|
||||||
|
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||||
|
updated_prompt = [system_prompt] + filtered_prompts
|
||||||
|
|
||||||
|
return updated_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||||
|
structured_output: Mapping[str, Any] = {}
|
||||||
|
parsed: Mapping[str, Any] = {}
|
||||||
|
try:
|
||||||
|
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||||
|
structured_output = parsed
|
||||||
|
except ValidationError:
|
||||||
|
# if the result_text is not a valid json, try to repair it
|
||||||
|
temp_parsed = json_repair.loads(result_text)
|
||||||
|
if not isinstance(temp_parsed, dict):
|
||||||
|
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||||
|
if isinstance(temp_parsed, list):
|
||||||
|
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||||
|
else:
|
||||||
|
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||||
|
structured_output = cast(dict, temp_parsed)
|
||||||
|
return structured_output
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
|
||||||
|
"""
|
||||||
|
Prepare JSON schema based on model requirements.
|
||||||
|
|
||||||
|
Different models have different requirements for JSON schema formatting.
|
||||||
|
This function handles these differences.
|
||||||
|
|
||||||
|
:param schema: The original JSON schema
|
||||||
|
:return: Processed schema compatible with the current model
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Deep copy to avoid modifying the original schema
|
||||||
|
processed_schema = dict(deepcopy(schema))
|
||||||
|
|
||||||
|
# Convert boolean types to string types (common requirement)
|
||||||
|
convert_boolean_to_string(processed_schema)
|
||||||
|
|
||||||
|
# Apply model-specific transformations
|
||||||
|
if SpecialModelType.GEMINI in model_schema.model:
|
||||||
|
remove_additional_properties(processed_schema)
|
||||||
|
return processed_schema
|
||||||
|
elif SpecialModelType.OLLAMA in provider:
|
||||||
|
return processed_schema
|
||||||
|
else:
|
||||||
|
# Default format with name field
|
||||||
|
return {"schema": processed_schema, "name": "llm_response"}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_additional_properties(schema: dict) -> None:
|
||||||
|
"""
|
||||||
|
Remove additionalProperties fields from JSON schema.
|
||||||
|
Used for models like Gemini that don't support this property.
|
||||||
|
|
||||||
|
:param schema: JSON schema to modify in-place
|
||||||
|
"""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove additionalProperties at current level
|
||||||
|
schema.pop("additionalProperties", None)
|
||||||
|
|
||||||
|
# Process nested structures recursively
|
||||||
|
for value in schema.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
remove_additional_properties(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
remove_additional_properties(item)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_boolean_to_string(schema: dict) -> None:
|
||||||
|
"""
|
||||||
|
Convert boolean type specifications to string in JSON schema.
|
||||||
|
|
||||||
|
:param schema: JSON schema to modify in-place
|
||||||
|
"""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check for boolean type at current level
|
||||||
|
if schema.get("type") == "boolean":
|
||||||
|
schema["type"] = "string"
|
||||||
|
|
||||||
|
# Process nested dictionaries and lists recursively
|
||||||
|
for value in schema.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
convert_boolean_to_string(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
convert_boolean_to_string(item)
|
||||||
@ -0,0 +1,45 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin import GenericProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||||
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicSelectClient(BasePluginClient):
|
||||||
|
def fetch_dynamic_select_options(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
plugin_id: str,
|
||||||
|
provider: str,
|
||||||
|
action: str,
|
||||||
|
credentials: Mapping[str, Any],
|
||||||
|
parameter: str,
|
||||||
|
) -> PluginDynamicSelectOptionsResponse:
|
||||||
|
"""
|
||||||
|
Fetch dynamic select options for a plugin parameter.
|
||||||
|
"""
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/dynamic_select/fetch_parameter_options",
|
||||||
|
PluginDynamicSelectOptionsResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": GenericProviderID(provider).provider_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
"provider_action": action,
|
||||||
|
"parameter": parameter,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for options in response:
|
||||||
|
return options
|
||||||
|
|
||||||
|
raise ValueError("Plugin service returned no options")
|
||||||
@ -1,8 +1,26 @@
|
|||||||
|
import json
|
||||||
from collections.abc import Iterable, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
|
|
||||||
|
from .segment_group import SegmentGroup
|
||||||
|
from .segments import ArrayFileSegment, FileSegment, Segment
|
||||||
|
|
||||||
|
|
||||||
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||||
selectors = [node_id, name]
|
selectors = [node_id, name]
|
||||||
if paths:
|
if paths:
|
||||||
selectors.extend(paths)
|
selectors.extend(paths)
|
||||||
return selectors
|
return selectors
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentJSONEncoder(json.JSONEncoder):
|
||||||
|
def default(self, o):
|
||||||
|
if isinstance(o, ArrayFileSegment):
|
||||||
|
return [v.model_dump() for v in o.value]
|
||||||
|
elif isinstance(o, FileSegment):
|
||||||
|
return o.value.model_dump()
|
||||||
|
elif isinstance(o, SegmentGroup):
|
||||||
|
return [self.default(seg) for seg in o.value]
|
||||||
|
elif isinstance(o, Segment):
|
||||||
|
return o.value
|
||||||
|
else:
|
||||||
|
super().default(o)
|
||||||
|
|||||||
@ -0,0 +1,39 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from core.variables import Variable
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableUpdater(Protocol):
|
||||||
|
"""
|
||||||
|
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
|
||||||
|
|
||||||
|
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
|
||||||
|
conversation variables.
|
||||||
|
|
||||||
|
Implementations may choose to batch updates. If batching is used, the `flush` method
|
||||||
|
should be implemented to persist buffered changes, and `update`
|
||||||
|
should handle buffering accordingly.
|
||||||
|
|
||||||
|
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
|
||||||
|
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def update(self, conversation_id: str, variable: "Variable") -> None:
|
||||||
|
"""
|
||||||
|
Updates the value of the specified conversation variable in the underlying storage.
|
||||||
|
|
||||||
|
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||||
|
:param variable: The `Variable` instance containing the updated value.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def flush(self):
|
||||||
|
"""
|
||||||
|
Flushes all pending updates to the underlying storage system.
|
||||||
|
|
||||||
|
If the implementation does not buffer updates, this method can be a no-op.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
@ -1,19 +1,55 @@
|
|||||||
from sqlalchemy import select
|
from collections.abc import Mapping, MutableMapping, Sequence
|
||||||
from sqlalchemy.orm import Session
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from core.variables import Variable
|
from pydantic import BaseModel
|
||||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models import ConversationVariable
|
|
||||||
|
|
||||||
|
from core.variables import Segment
|
||||||
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||||
|
from core.variables.types import SegmentType
|
||||||
|
|
||||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
# Use double underscore (`__`) prefix for internal variables
|
||||||
stmt = select(ConversationVariable).where(
|
# to minimize risk of collision with user-defined variable names.
|
||||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
_UPDATED_VARIABLES_KEY = "__updated_variables"
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatedVariable(BaseModel):
|
||||||
|
name: str
|
||||||
|
selector: Sequence[str]
|
||||||
|
value_type: SegmentType
|
||||||
|
new_value: Any
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||||
|
|
||||||
|
|
||||||
|
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||||
|
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||||
|
raise Exception("selector too short")
|
||||||
|
node_id, var_name = selector[:2]
|
||||||
|
return UpdatedVariable(
|
||||||
|
name=var_name,
|
||||||
|
selector=list(selector[:2]),
|
||||||
|
value_type=seg.value_type,
|
||||||
|
new_value=seg.value,
|
||||||
)
|
)
|
||||||
with Session(db.engine) as session:
|
|
||||||
row = session.scalar(stmt)
|
|
||||||
if not row:
|
def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T:
|
||||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
m[_UPDATED_VARIABLES_KEY] = updates
|
||||||
row.data = variable.model_dump_json()
|
return m
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None:
|
||||||
|
updated_values = m.get(_UPDATED_VARIABLES_KEY, None)
|
||||||
|
if updated_values is None:
|
||||||
|
return None
|
||||||
|
result = []
|
||||||
|
for items in updated_values:
|
||||||
|
if isinstance(items, UpdatedVariable):
|
||||||
|
result.append(items)
|
||||||
|
elif isinstance(items, dict):
|
||||||
|
items = UpdatedVariable.model_validate(items)
|
||||||
|
result.append(items)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid updated variable: {items}, type={type(items)}")
|
||||||
|
return result
|
||||||
|
|||||||
@ -0,0 +1,38 @@
|
|||||||
|
from sqlalchemy import Engine, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.variables.variables import Variable
|
||||||
|
from models.engine import db
|
||||||
|
from models.workflow import ConversationVariable
|
||||||
|
|
||||||
|
from .exc import VariableOperatorNodeError
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableUpdaterImpl:
|
||||||
|
_engine: Engine | None
|
||||||
|
|
||||||
|
def __init__(self, engine: Engine | None = None) -> None:
|
||||||
|
self._engine = engine
|
||||||
|
|
||||||
|
def _get_engine(self) -> Engine:
|
||||||
|
if self._engine:
|
||||||
|
return self._engine
|
||||||
|
return db.engine
|
||||||
|
|
||||||
|
def update(self, conversation_id: str, variable: Variable):
|
||||||
|
stmt = select(ConversationVariable).where(
|
||||||
|
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||||
|
)
|
||||||
|
with Session(self._get_engine()) as session:
|
||||||
|
row = session.scalar(stmt)
|
||||||
|
if not row:
|
||||||
|
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||||
|
row.data = variable.model_dump_json()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||||
|
return ConversationVariableUpdaterImpl()
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue