fix(api): fix variable type handling in `Start` node.

Fix the issue that `file` and `array[file]` variable in `Start` node are recognized as `object`
type.
pull/20699/head
QuantumGhost 11 months ago
parent ea89d2a17c
commit b58a515f5b

@ -439,6 +439,7 @@ class DraftWorkflowNodeRunApi(Resource):
raise ValueError("Workflow not initialized")
files = _parse_file(draft_workflow, args.get("files"))
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,

@ -104,6 +104,7 @@ class VariableEntity(BaseModel):
Variable Entity.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""

@ -120,7 +120,7 @@ class WorkflowEntry:
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict,
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
@ -326,7 +326,7 @@ class WorkflowEntry:
cls,
*,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: dict,
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
tenant_id: str,
) -> None:

@ -9,6 +9,7 @@ from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -27,9 +28,11 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
@ -311,7 +314,7 @@ class WorkflowService:
app_model: App,
draft_workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: Mapping[str, Any],
account: Account,
query: str = "",
files: Sequence[File] | None = None,
@ -326,7 +329,8 @@ class WorkflowService:
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = NodeType(node_config.get("data", {}).get("type"))
node_type = Workflow.get_node_type_from_node_config(node_config)
node_data = node_config.get("data", {})
if node_type == NodeType.START:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
@ -335,7 +339,10 @@ class WorkflowService:
app=app_model,
workflow=draft_workflow,
)
start_data = StartNodeData.model_validate(node_data)
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
# init variable pool
variable_pool = _setup_variable_pool(
query=query,
@ -362,7 +369,6 @@ class WorkflowService:
app_id=app_model.id,
)
node_config = draft_workflow.get_node_config_by_id(node_id)
eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
if eclosing_node_type_and_id:
_, enclosing_node_id = eclosing_node_type_and_id
@ -699,3 +705,36 @@ def _setup_variable_pool(
)
return variable_pool
def _rebuild_file_for_user_inputs_in_start_node(
tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any]
) -> Mapping[str, Any]:
inputs_copy = dict(user_inputs)
for variable in start_node_data.variables:
if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST):
continue
if variable.variable not in user_inputs:
continue
value = user_inputs[variable.variable]
file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
inputs_copy[variable.variable] = file
return inputs_copy
def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]:
if variable_entity_type == VariableEntityType.FILE:
if not isinstance(value, dict):
raise ValueError(f"expected dict for file object, got {type(value)}")
return build_from_mapping(mapping=value, tenant_id=tenant_id)
elif variable_entity_type == VariableEntityType.FILE_LIST:
if not isinstance(value, list):
raise ValueError(f"expected list for file list object, got {type(value)}")
if len(value) == 0:
return []
if not isinstance(value[0], dict):
raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
return build_from_mappings(mappings=value, tenant_id=tenant_id)
else:
raise Exception("unreachable")

Loading…
Cancel
Save