Merge branch 'main' into feat/r2
# Conflicts: # api/core/plugin/impl/oauth.py # api/core/workflow/entities/variable_pool.py # api/models/workflow.py # api/services/dataset_service.pyfeat/r2
commit
540096a8d8
@ -0,0 +1,421 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
|
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import (
|
||||||
|
DraftWorkflowNotExist,
|
||||||
|
)
|
||||||
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||||
|
from core.variables.segment_group import SegmentGroup
|
||||||
|
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||||
|
from core.variables.types import SegmentType
|
||||||
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
|
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||||
|
from factories.variable_factory import build_segment_with_type
|
||||||
|
from libs.login import current_user, login_required
|
||||||
|
from models import App, AppMode, db
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||||
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||||
|
if isinstance(value, FileSegment):
|
||||||
|
return value.value.model_dump()
|
||||||
|
elif isinstance(value, ArrayFileSegment):
|
||||||
|
return [i.model_dump() for i in value.value]
|
||||||
|
elif isinstance(value, SegmentGroup):
|
||||||
|
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||||
|
else:
|
||||||
|
return value.value
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||||
|
value = variable.get_value()
|
||||||
|
# create a copy of the value to avoid affecting the model cache.
|
||||||
|
value = value.model_copy(deep=True)
|
||||||
|
# Refresh the url signature before returning it to client.
|
||||||
|
if isinstance(value, FileSegment):
|
||||||
|
file = value.value
|
||||||
|
file.remote_url = file.generate_url()
|
||||||
|
elif isinstance(value, ArrayFileSegment):
|
||||||
|
files = value.value
|
||||||
|
for file in files:
|
||||||
|
file.remote_url = file.generate_url()
|
||||||
|
return _convert_values_to_json_serializable_object(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_pagination_parser():
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"page",
|
||||||
|
type=inputs.int_range(1, 100_000),
|
||||||
|
required=False,
|
||||||
|
default=1,
|
||||||
|
location="args",
|
||||||
|
help="the page of data requested",
|
||||||
|
)
|
||||||
|
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||||
|
"id": fields.String,
|
||||||
|
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
|
"value_type": fields.String,
|
||||||
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
|
"visible": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||||
|
value=fields.Raw(attribute=_serialize_var_value),
|
||||||
|
)
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||||
|
"id": fields.String,
|
||||||
|
"type": fields.String(attribute=lambda _: "env"),
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
|
"value_type": fields.String,
|
||||||
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
|
"visible": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||||
|
return var_list.variables
|
||||||
|
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||||
|
"total": fields.Raw(),
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _api_prerequisite(f):
|
||||||
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
|
It ensures the following conditions are satisfied:
|
||||||
|
|
||||||
|
- Dify has been property setup.
|
||||||
|
- The request user has logged in and initialized.
|
||||||
|
- The requested app is a workflow or a chat flow.
|
||||||
|
- The request user has the edit permission for the app.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get draft workflow
|
||||||
|
"""
|
||||||
|
parser = _create_pagination_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||||
|
if not workflow_exist:
|
||||||
|
raise DraftWorkflowNotExist()
|
||||||
|
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||||
|
app_id=app_model.id,
|
||||||
|
page=args.page,
|
||||||
|
limit=args.limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
return workflow_vars
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||||
|
if node_id in [
|
||||||
|
CONVERSATION_VARIABLE_NODE_ID,
|
||||||
|
SYSTEM_VARIABLE_NODE_ID,
|
||||||
|
]:
|
||||||
|
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||||
|
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||||
|
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||||
|
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||||
|
#
|
||||||
|
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||||
|
|
||||||
|
raise InvalidArgumentError(
|
||||||
|
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NodeVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App, node_id: str):
|
||||||
|
validate_node_id(node_id)
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||||
|
|
||||||
|
return node_vars
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App, node_id: str):
|
||||||
|
validate_node_id(node_id)
|
||||||
|
srv = WorkflowDraftVariableService(db.session())
|
||||||
|
srv.delete_node_variables(app_model.id, node_id)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableApi(Resource):
|
||||||
|
_PATCH_NAME_FIELD = "name"
|
||||||
|
_PATCH_VALUE_FIELD = "value"
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
def get(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
return variable
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
def patch(self, app_model: App, variable_id: str):
|
||||||
|
# Request payload for file types:
|
||||||
|
#
|
||||||
|
# Local File:
|
||||||
|
#
|
||||||
|
# {
|
||||||
|
# "type": "image",
|
||||||
|
# "transfer_method": "local_file",
|
||||||
|
# "url": "",
|
||||||
|
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# Remote File:
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# {
|
||||||
|
# "type": "image",
|
||||||
|
# "transfer_method": "remote_url",
|
||||||
|
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||||
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
|
# }
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||||
|
# Parse 'value' field as-is to maintain its original data structure
|
||||||
|
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||||
|
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
args = parser.parse_args(strict=True)
|
||||||
|
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
|
||||||
|
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||||
|
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||||
|
if new_name is None and raw_value is None:
|
||||||
|
return variable
|
||||||
|
|
||||||
|
new_value = None
|
||||||
|
if raw_value is not None:
|
||||||
|
if variable.value_type == SegmentType.FILE:
|
||||||
|
if not isinstance(raw_value, dict):
|
||||||
|
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||||
|
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
|
||||||
|
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||||
|
if not isinstance(raw_value, list):
|
||||||
|
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||||
|
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||||
|
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||||
|
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||||
|
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||||
|
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||||
|
db.session.commit()
|
||||||
|
return variable
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
draft_var_srv.delete_variable(variable)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableResetApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
def put(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_srv = WorkflowService()
|
||||||
|
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||||
|
if draft_workflow is None:
|
||||||
|
raise NotFoundError(
|
||||||
|
f"Draft workflow not found, app_id={app_model.id}",
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
|
||||||
|
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||||
|
db.session.commit()
|
||||||
|
if resetted is None:
|
||||||
|
return Response("", 204)
|
||||||
|
else:
|
||||||
|
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||||
|
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||||
|
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||||
|
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||||
|
else:
|
||||||
|
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||||
|
return draft_vars
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||||
|
# so their IDs can be returned to the caller.
|
||||||
|
workflow_srv = WorkflowService()
|
||||||
|
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||||
|
if draft_workflow is None:
|
||||||
|
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||||
|
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||||
|
db.session.commit()
|
||||||
|
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class SystemVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get draft workflow
|
||||||
|
"""
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||||
|
if workflow is None:
|
||||||
|
raise DraftWorkflowNotExist()
|
||||||
|
|
||||||
|
env_vars = workflow.environment_variables
|
||||||
|
env_vars_list = []
|
||||||
|
for v in env_vars:
|
||||||
|
env_vars_list.append(
|
||||||
|
{
|
||||||
|
"id": v.id,
|
||||||
|
"type": "env",
|
||||||
|
"name": v.name,
|
||||||
|
"description": v.description,
|
||||||
|
"selector": v.selector,
|
||||||
|
"value_type": v.value_type.value,
|
||||||
|
"value": v.value,
|
||||||
|
# Do not track edited for env vars.
|
||||||
|
"edited": False,
|
||||||
|
"visible": True,
|
||||||
|
"editable": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"items": env_vars_list}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
WorkflowVariableCollectionApi,
|
||||||
|
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||||
|
)
|
||||||
|
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||||
|
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||||
|
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||||
|
|
||||||
|
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||||
|
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||||
|
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||||
@ -1 +1,11 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
|
||||||
|
# comparing `dify_model_identity` with constants throughout the codebase, extract
|
||||||
|
# this logic into a dedicated function. This would encapsulate the implementation
|
||||||
|
# details of how different variable types are identified.
|
||||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_file_object(o: Any) -> bool:
|
||||||
|
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||||
|
|||||||
@ -1,8 +1,26 @@
|
|||||||
|
import json
|
||||||
from collections.abc import Iterable, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
|
|
||||||
|
from .segment_group import SegmentGroup
|
||||||
|
from .segments import ArrayFileSegment, FileSegment, Segment
|
||||||
|
|
||||||
|
|
||||||
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||||
selectors = [node_id, name]
|
selectors = [node_id, name]
|
||||||
if paths:
|
if paths:
|
||||||
selectors.extend(paths)
|
selectors.extend(paths)
|
||||||
return selectors
|
return selectors
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentJSONEncoder(json.JSONEncoder):
|
||||||
|
def default(self, o):
|
||||||
|
if isinstance(o, ArrayFileSegment):
|
||||||
|
return [v.model_dump() for v in o.value]
|
||||||
|
elif isinstance(o, FileSegment):
|
||||||
|
return o.value.model_dump()
|
||||||
|
elif isinstance(o, SegmentGroup):
|
||||||
|
return [self.default(seg) for seg in o.value]
|
||||||
|
elif isinstance(o, Segment):
|
||||||
|
return o.value
|
||||||
|
else:
|
||||||
|
super().default(o)
|
||||||
|
|||||||
@ -0,0 +1,39 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from core.variables import Variable
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableUpdater(Protocol):
|
||||||
|
"""
|
||||||
|
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
|
||||||
|
|
||||||
|
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
|
||||||
|
conversation variables.
|
||||||
|
|
||||||
|
Implementations may choose to batch updates. If batching is used, the `flush` method
|
||||||
|
should be implemented to persist buffered changes, and `update`
|
||||||
|
should handle buffering accordingly.
|
||||||
|
|
||||||
|
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
|
||||||
|
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def update(self, conversation_id: str, variable: "Variable") -> None:
|
||||||
|
"""
|
||||||
|
Updates the value of the specified conversation variable in the underlying storage.
|
||||||
|
|
||||||
|
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||||
|
:param variable: The `Variable` instance containing the updated value.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def flush(self):
|
||||||
|
"""
|
||||||
|
Flushes all pending updates to the underlying storage system.
|
||||||
|
|
||||||
|
If the implementation does not buffer updates, this method can be a no-op.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
@ -1,19 +1,55 @@
|
|||||||
from sqlalchemy import select
|
from collections.abc import Mapping, MutableMapping, Sequence
|
||||||
from sqlalchemy.orm import Session
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from core.variables import Variable
|
from pydantic import BaseModel
|
||||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models import ConversationVariable
|
|
||||||
|
|
||||||
|
from core.variables import Segment
|
||||||
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||||
|
from core.variables.types import SegmentType
|
||||||
|
|
||||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
# Use double underscore (`__`) prefix for internal variables
|
||||||
stmt = select(ConversationVariable).where(
|
# to minimize risk of collision with user-defined variable names.
|
||||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
_UPDATED_VARIABLES_KEY = "__updated_variables"
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatedVariable(BaseModel):
|
||||||
|
name: str
|
||||||
|
selector: Sequence[str]
|
||||||
|
value_type: SegmentType
|
||||||
|
new_value: Any
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||||
|
|
||||||
|
|
||||||
|
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||||
|
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||||
|
raise Exception("selector too short")
|
||||||
|
node_id, var_name = selector[:2]
|
||||||
|
return UpdatedVariable(
|
||||||
|
name=var_name,
|
||||||
|
selector=list(selector[:2]),
|
||||||
|
value_type=seg.value_type,
|
||||||
|
new_value=seg.value,
|
||||||
)
|
)
|
||||||
with Session(db.engine) as session:
|
|
||||||
row = session.scalar(stmt)
|
|
||||||
if not row:
|
def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T:
|
||||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
m[_UPDATED_VARIABLES_KEY] = updates
|
||||||
row.data = variable.model_dump_json()
|
return m
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None:
|
||||||
|
updated_values = m.get(_UPDATED_VARIABLES_KEY, None)
|
||||||
|
if updated_values is None:
|
||||||
|
return None
|
||||||
|
result = []
|
||||||
|
for items in updated_values:
|
||||||
|
if isinstance(items, UpdatedVariable):
|
||||||
|
result.append(items)
|
||||||
|
elif isinstance(items, dict):
|
||||||
|
items = UpdatedVariable.model_validate(items)
|
||||||
|
result.append(items)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid updated variable: {items}, type={type(items)}")
|
||||||
|
return result
|
||||||
|
|||||||
@ -0,0 +1,38 @@
|
|||||||
|
from sqlalchemy import Engine, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.variables.variables import Variable
|
||||||
|
from models.engine import db
|
||||||
|
from models.workflow import ConversationVariable
|
||||||
|
|
||||||
|
from .exc import VariableOperatorNodeError
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableUpdaterImpl:
|
||||||
|
_engine: Engine | None
|
||||||
|
|
||||||
|
def __init__(self, engine: Engine | None = None) -> None:
|
||||||
|
self._engine = engine
|
||||||
|
|
||||||
|
def _get_engine(self) -> Engine:
|
||||||
|
if self._engine:
|
||||||
|
return self._engine
|
||||||
|
return db.engine
|
||||||
|
|
||||||
|
def update(self, conversation_id: str, variable: Variable):
|
||||||
|
stmt = select(ConversationVariable).where(
|
||||||
|
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||||
|
)
|
||||||
|
with Session(self._get_engine()) as session:
|
||||||
|
row = session.scalar(stmt)
|
||||||
|
if not row:
|
||||||
|
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||||
|
row.data = variable.model_dump_json()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||||
|
return ConversationVariableUpdaterImpl()
|
||||||
@ -0,0 +1,28 @@
|
|||||||
|
from core.variables.segments import ObjectSegment, Segment
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||||
|
|
||||||
|
|
||||||
|
def append_variables_recursively(
|
||||||
|
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Append variables recursively
|
||||||
|
:param node_id: node id
|
||||||
|
:param variable_key_list: variable key list
|
||||||
|
:param variable_value: variable value
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pool.add([node_id] + variable_key_list, variable_value)
|
||||||
|
|
||||||
|
# if variable_value is a dict, then recursively append variables
|
||||||
|
if isinstance(variable_value, ObjectSegment):
|
||||||
|
variable_dict = variable_value.value
|
||||||
|
elif isinstance(variable_value, dict):
|
||||||
|
variable_dict = variable_value
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
for key, value in variable_dict.items():
|
||||||
|
# construct new key list
|
||||||
|
new_key_list = variable_key_list + [key]
|
||||||
|
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)
|
||||||
@ -0,0 +1,84 @@
|
|||||||
|
import abc
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from core.variables import Variable
|
||||||
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.utils import variable_utils
|
||||||
|
|
||||||
|
|
||||||
|
class VariableLoader(Protocol):
|
||||||
|
"""Interface for loading variables based on selectors.
|
||||||
|
|
||||||
|
A `VariableLoader` is responsible for retrieving additional variables required during the execution
|
||||||
|
of a single node, which are not provided as user inputs.
|
||||||
|
|
||||||
|
NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same
|
||||||
|
application and share the same `app_id`. However, this interface does not enforce that constraint,
|
||||||
|
and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of
|
||||||
|
concern and allow for flexible implementations.
|
||||||
|
|
||||||
|
Implementations of `VariableLoader` should almost always have an `app_id` parameter in
|
||||||
|
their constructor.
|
||||||
|
|
||||||
|
TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into
|
||||||
|
`WorkflowService.single_step_run`, we may get rid of this interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||||
|
"""Load variables based on the provided selectors. If the selectors are empty,
|
||||||
|
this method should return an empty list.
|
||||||
|
|
||||||
|
The order of the returned variables is not guaranteed. If the caller wants to ensure
|
||||||
|
a specific order, they should sort the returned list themselves.
|
||||||
|
|
||||||
|
:param: selectors: a list of string list, each inner list should have at least two elements:
|
||||||
|
- the first element is the node ID,
|
||||||
|
- the second element is the variable name.
|
||||||
|
:return: a list of Variable objects that match the provided selectors.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyVariableLoader(VariableLoader):
|
||||||
|
"""A dummy implementation of VariableLoader that does not load any variables.
|
||||||
|
Serves as a placeholder when no variable loading is needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
|
||||||
|
|
||||||
|
|
||||||
|
def load_into_variable_pool(
|
||||||
|
variable_loader: VariableLoader,
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
variable_mapping: Mapping[str, Sequence[str]],
|
||||||
|
user_inputs: Mapping[str, Any],
|
||||||
|
):
|
||||||
|
# Loading missing variable from draft var here, and set it into
|
||||||
|
# variable_pool.
|
||||||
|
variables_to_load: list[list[str]] = []
|
||||||
|
for key, selector in variable_mapping.items():
|
||||||
|
# NOTE(QuantumGhost): this logic needs to be in sync with
|
||||||
|
# `WorkflowEntry.mapping_user_inputs_to_variable_pool`.
|
||||||
|
node_variable_list = key.split(".")
|
||||||
|
if len(node_variable_list) < 1:
|
||||||
|
raise ValueError(f"Invalid variable key: {key}. It should have at least one element.")
|
||||||
|
if key in user_inputs:
|
||||||
|
continue
|
||||||
|
node_variable_key = ".".join(node_variable_list[1:])
|
||||||
|
if node_variable_key in user_inputs:
|
||||||
|
continue
|
||||||
|
if variable_pool.get(selector) is None:
|
||||||
|
variables_to_load.append(list(selector))
|
||||||
|
loaded = variable_loader.load_variables(variables_to_load)
|
||||||
|
for var in loaded:
|
||||||
|
assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
|
||||||
|
variable_utils.append_variables_recursively(
|
||||||
|
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
|
||||||
|
)
|
||||||
@ -0,0 +1,49 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.file.models import File
|
||||||
|
from core.variables import Segment
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
|
||||||
|
def default(self, o: Any):
|
||||||
|
if isinstance(o, Segment):
|
||||||
|
return o.value
|
||||||
|
elif isinstance(o, File):
|
||||||
|
return o.to_dict()
|
||||||
|
elif isinstance(o, BaseModel):
|
||||||
|
return o.model_dump(mode="json")
|
||||||
|
else:
|
||||||
|
return super().default(o)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRuntimeTypeConverter:
|
||||||
|
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||||
|
result = self._to_json_encodable_recursive(value)
|
||||||
|
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||||
|
|
||||||
|
def _to_json_encodable_recursive(self, value: Any) -> Any:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
if isinstance(value, (bool, int, str, float)):
|
||||||
|
return value
|
||||||
|
if isinstance(value, Segment):
|
||||||
|
return self._to_json_encodable_recursive(value.value)
|
||||||
|
if isinstance(value, File):
|
||||||
|
return value.to_dict()
|
||||||
|
if isinstance(value, BaseModel):
|
||||||
|
return value.model_dump(mode="json")
|
||||||
|
if isinstance(value, dict):
|
||||||
|
res = {}
|
||||||
|
for k, v in value.items():
|
||||||
|
res[k] = self._to_json_encodable_recursive(v)
|
||||||
|
return res
|
||||||
|
if isinstance(value, list):
|
||||||
|
res_list = []
|
||||||
|
for item in value:
|
||||||
|
res_list.append(self._to_json_encodable_recursive(item))
|
||||||
|
return res_list
|
||||||
|
return value
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
import abc
|
||||||
|
import datetime
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class _NowFunction(Protocol):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# _now_func is a callable with the _NowFunction signature.
|
||||||
|
# Its sole purpose is to abstract time retrieval, enabling
|
||||||
|
# developers to mock this behavior in tests and time-dependent scenarios.
|
||||||
|
_now_func: _NowFunction = datetime.datetime.now
|
||||||
|
|
||||||
|
|
||||||
|
def naive_utc_now() -> datetime.datetime:
|
||||||
|
"""Return a naive datetime object (without timezone information)
|
||||||
|
representing current UTC time.
|
||||||
|
"""
|
||||||
|
return _now_func(datetime.UTC).replace(tzinfo=None)
|
||||||
@ -0,0 +1,11 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticModelEncoder(json.JSONEncoder):
|
||||||
|
def default(self, o):
|
||||||
|
if isinstance(o, BaseModel):
|
||||||
|
return o.model_dump()
|
||||||
|
else:
|
||||||
|
super().default(o)
|
||||||
@ -0,0 +1,66 @@
|
|||||||
|
"""remove sequence_number from workflow_runs
|
||||||
|
|
||||||
|
Revision ID: 0ab65e1cc7fa
|
||||||
|
Revises: 4474872b0ee6
|
||||||
|
Create Date: 2025-06-19 16:33:13.377215
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '0ab65e1cc7fa'
|
||||||
|
down_revision = '4474872b0ee6'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index(batch_op.f('workflow_run_tenant_app_sequence_idx'))
|
||||||
|
batch_op.drop_column('sequence_number')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
|
||||||
|
# WARNING: This downgrade CANNOT recover the original sequence_number values!
|
||||||
|
# The original sequence numbers are permanently lost after the upgrade.
|
||||||
|
# This downgrade will regenerate sequence numbers based on created_at order,
|
||||||
|
# which may result in different values than the original sequence numbers.
|
||||||
|
#
|
||||||
|
# If you need to preserve original sequence numbers, use the alternative
|
||||||
|
# migration approach that creates a backup table before removal.
|
||||||
|
|
||||||
|
# Step 1: Add sequence_number column as nullable first
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('sequence_number', sa.INTEGER(), autoincrement=False, nullable=True))
|
||||||
|
|
||||||
|
# Step 2: Populate sequence_number values based on created_at order within each app
|
||||||
|
# NOTE: This recreates sequence numbering logic but values will be different
|
||||||
|
# from the original sequence numbers that were removed in the upgrade
|
||||||
|
connection = op.get_bind()
|
||||||
|
connection.execute(sa.text("""
|
||||||
|
UPDATE workflow_runs
|
||||||
|
SET sequence_number = subquery.row_num
|
||||||
|
FROM (
|
||||||
|
SELECT id, ROW_NUMBER() OVER (
|
||||||
|
PARTITION BY tenant_id, app_id
|
||||||
|
ORDER BY created_at, id
|
||||||
|
) as row_num
|
||||||
|
FROM workflow_runs
|
||||||
|
) subquery
|
||||||
|
WHERE workflow_runs.id = subquery.id
|
||||||
|
"""))
|
||||||
|
|
||||||
|
# Step 3: Make the column NOT NULL and add the index
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('sequence_number', nullable=False)
|
||||||
|
batch_op.create_index(batch_op.f('workflow_run_tenant_app_sequence_idx'), ['tenant_id', 'app_id', 'sequence_number'], unique=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
"""All these exceptions are not meant to be caught by callers."""
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowDataError(Exception):
|
||||||
|
"""Base class for all workflow data related exceptions.
|
||||||
|
|
||||||
|
This should be used to indicate issues with workflow data integrity, such as
|
||||||
|
no `graph` configuration, missing `nodes` field in `graph` configuration, or
|
||||||
|
similar issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NodeNotFoundError(WorkflowDataError):
|
||||||
|
"""Raised when a node with the specified ID is not found in the workflow."""
|
||||||
|
|
||||||
|
def __init__(self, node_id: str):
|
||||||
|
super().__init__(f"Node with ID '{node_id}' not found in the workflow.")
|
||||||
|
self.node_id = node_id
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue