diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 1dae1295e4..a9f088a276 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -439,6 +439,7 @@ class DraftWorkflowNodeRunApi(Resource): raise ValueError("Workflow not initialized") files = _parse_file(draft_workflow, args.get("files")) workflow_service = WorkflowService() + workflow_node_execution = workflow_service.run_draft_workflow_node( app_model=app_model, draft_workflow=draft_workflow, diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 3f31b1c3d5..75bd2f677a 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -104,6 +104,7 @@ class VariableEntity(BaseModel): Variable Entity. """ + # `variable` records the name of the variable in user inputs. variable: str label: str description: str = "" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index e8d7760f01..ddf0620077 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -120,7 +120,7 @@ class WorkflowEntry: workflow: Workflow, node_id: str, user_id: str, - user_inputs: dict, + user_inputs: Mapping[str, Any], variable_pool: VariablePool, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: @@ -326,7 +326,7 @@ class WorkflowEntry: cls, *, variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, + user_inputs: Mapping[str, Any], variable_pool: VariablePool, tenant_id: str, ) -> None: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 129ea10141..ac9db869cc 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 from sqlalchemy import select from sqlalchemy.orm import Session +from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -27,9 +28,11 @@ from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.nodes.start.entities import StartNodeData from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db +from factories.file_factory import build_from_mapping, build_from_mappings from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -311,7 +314,7 @@ class WorkflowService: app_model: App, draft_workflow: Workflow, node_id: str, - user_inputs: dict, + user_inputs: Mapping[str, Any], account: Account, query: str = "", files: Sequence[File] | None = None, @@ -326,7 +329,8 @@ class WorkflowService: draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) node_config = draft_workflow.get_node_config_by_id(node_id) - node_type = NodeType(node_config.get("data", {}).get("type")) + node_type = Workflow.get_node_type_from_node_config(node_config) + node_data = node_config.get("data", {}) if node_type == NodeType.START: with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) @@ -335,7 +339,10 @@ class WorkflowService: app=app_model, workflow=draft_workflow, ) - + start_data = StartNodeData.model_validate(node_data) + user_inputs = _rebuild_file_for_user_inputs_in_start_node( + tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs + ) # init variable pool variable_pool = _setup_variable_pool( query=query, @@ -362,7 +369,6 @@ class WorkflowService: app_id=app_model.id, ) - node_config = draft_workflow.get_node_config_by_id(node_id) eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) if eclosing_node_type_and_id: _, enclosing_node_id = eclosing_node_type_and_id @@ -699,3 +705,36 @@ def _setup_variable_pool( ) return variable_pool + + +def _rebuild_file_for_user_inputs_in_start_node( + tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any] +) -> Mapping[str, Any]: + inputs_copy = dict(user_inputs) + + for variable in start_node_data.variables: + if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST): + continue + if variable.variable not in user_inputs: + continue + value = user_inputs[variable.variable] + file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type) + inputs_copy[variable.variable] = file + return inputs_copy + + +def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]: + if variable_entity_type == VariableEntityType.FILE: + if not isinstance(value, dict): + raise ValueError(f"expected dict for file object, got {type(value)}") + return build_from_mapping(mapping=value, tenant_id=tenant_id) + elif variable_entity_type == VariableEntityType.FILE_LIST: + if not isinstance(value, list): + raise ValueError(f"expected list for file list object, got {type(value)}") + if len(value) == 0: + return [] + if not isinstance(value[0], dict): + raise ValueError(f"expected dict for first element in the file list, got {type(value)}") + return build_from_mappings(mappings=value, tenant_id=tenant_id) + else: + raise Exception("unreachable")