feat(api): use `Array*Segment` for array types in `NodeRunResult.outputs` if possible.

Previously, the `outputs` in `NodeRunResult` were primitive Python types
such as `list` and `dict`. This brings challenges to type detection as
the type information has been removed during the output persistence
process. (especially for empty list, since we cannot infer its type by
peeking at its elements)

This commit change the outputs to `Array*Segment`, ensure that the type
of output arrays are perserved.
pull/20699/head
QuantumGhost 11 months ago
parent b46a56c272
commit 3508afcd32

@ -1,6 +1,4 @@
exclude = [
"migrations/*",
]
exclude = ["migrations/*"]
line-length = 120
[format]
@ -9,14 +7,14 @@ quote-style = "double"
[lint]
preview = false
select = [
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"PT", # flake8-pytest-style rules
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"PT", # flake8-pytest-style rules
"PLC0208", # iteration-over-set
"PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object
@ -24,19 +22,19 @@ select = [
"PLR0402", # manual-from-import
"PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison
"RUF013", # implicit-optional
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
"RUF013", # implicit-optional
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
@ -47,36 +45,37 @@ select = [
]
ignore = [
"E402", # module-import-not-at-top-of-file
"E711", # none-comparison
"E712", # true-false-comparison
"E721", # type-comparison
"E722", # bare-except
"F821", # undefined-name
"F841", # unused-variable
"E402", # module-import-not-at-top-of-file
"E711", # none-comparison
"E712", # true-false-comparison
"E721", # type-comparison
"E722", # bare-except
"F821", # undefined-name
"F841", # unused-variable
"FURB113", # repeated-append
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]
[lint.per-file-ignores]

@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@ -125,7 +126,7 @@ class WorkflowResponseConverter:
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status,
outputs=workflow_execution.outputs,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
@ -202,6 +203,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -214,7 +217,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -245,6 +248,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -257,7 +262,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -376,6 +381,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -384,7 +390,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -463,7 +469,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},

@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import (
WorkflowType,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@ -165,7 +166,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
db_model.outputs = (
json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs))
if domain_model.outputs
else None
)
db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.total_tokens = domain_model.total_tokens

@ -19,7 +19,7 @@ from core.workflow.entities.workflow_node_execution import (
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from libs.jsonutil import PydanticModelEncoder
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@ -147,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
json_converter = WorkflowRuntimeTypeConverter()
db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
@ -161,11 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.node_id = domain_model.node_id
db_model.node_type = domain_model.node_type
db_model.title = domain_model.title
db_model.inputs = json.dumps(domain_model.inputs, cls=PydanticModelEncoder) if domain_model.inputs else None
db_model.inputs = (
json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None
)
db_model.process_data = (
json.dumps(domain_model.process_data, cls=PydanticModelEncoder) if domain_model.process_data else None
json.dumps(json_converter.to_json_encodable(domain_model.process_data))
if domain_model.process_data
else None
)
db_model.outputs = (
json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None
)
db_model.outputs = json.dumps(domain_model.outputs, cls=PydanticModelEncoder) if domain_model.outputs else None
db_model.status = domain_model.status
db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time

@ -18,3 +18,17 @@ class SegmentType(StrEnum):
NONE = "none"
GROUP = "group"
def is_array_type(self):
return self in _ARRAY_TYPES
_ARRAY_TYPES = frozenset(
[
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
)

@ -49,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
part = cast(TextGenerateRouteChunk, part)
answer += part.text
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
)
@classmethod
def _extract_variable_selector_to_variable_mapping(

@ -130,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]):
prefix: str = "",
depth: int = 1,
):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")

@ -24,7 +24,7 @@ from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -71,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": extracted_text_list},
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
)
elif isinstance(value, File):
extracted_text = _extract_text_from_file(value)

@ -6,6 +6,7 @@ from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -170,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
return mapping
def extract_files(self, url: str, response: Response) -> list[File]:
def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
"""
Extract files from response by checking both Content-Type header and URL
"""
@ -182,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content_disposition_type = None
if not is_file:
return files
return ArrayFileSegment(value=[])
if parsed_content_disposition:
content_disposition_filename = parsed_content_disposition.get_filename()
@ -215,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
)
files.append(file)
return files
return ArrayFileSegment(value=files)

@ -11,6 +11,7 @@ from flask import Flask, current_app
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
)
@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.flask_utils import preserve_flask_contexts
from .exc import (
@ -89,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
# Try our best to preserve the type informat.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": []},
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
return
@ -235,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]):
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)
yield IterationRunSucceededEvent(
iteration_id=self.id,
@ -251,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
outputs={"output": output_segment},
metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,

@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs, # type: ignore
)
except KnowledgeRetrievalNodeError as e:

@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -34,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
if isinstance(variable, ArraySegment):
result = variable.model_copy(update={"value": []})
else:
result = ArrayAnySegment(value=[])
outputs = {"result": result, "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@ -75,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
variable = self._apply_slice(variable)
outputs = {
"result": variable.value,
"result": variable,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}

@ -259,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if structured_output:
outputs["structured_output"] = structured_output
if self._file_outputs is not None:
outputs["files"] = self._file_outputs
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
yield RunCompletedEvent(
run_result=NodeRunResult(

@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
class _ParameterConfigError(Exception):
pass
class ParameterConfig(BaseModel):
"""
Parameter Config.
@ -27,6 +31,19 @@ class ParameterConfig(BaseModel):
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
class ParameterExtractorNodeData(BaseNodeData):
"""

@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
@ -588,28 +590,30 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
elif parameter.type.startswith("array"):
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
nested_type = parameter.type[6:-1]
transformed_result[parameter.name] = []
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
transformed_result[parameter.name].append(float(item))
segment_value.value.append(float(item))
else:
transformed_result[parameter.name].append(int(item))
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
if isinstance(item, str):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
elif nested_type == "object":
if isinstance(item, dict):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
@ -619,7 +623,9 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
transformed_result[parameter.name] = []
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
return transformed_result

@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@ -304,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]):
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
@ -367,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,

@ -22,7 +22,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable.to_object()}
outputs = {"output": variable}
inputs = {".".join(selector[1:]): variable.to_object()}
break
@ -32,7 +32,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs[group.group_name] = {"output": variable.to_object()}
outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable.to_object()
break

@ -92,7 +92,7 @@ class WorkflowCycleManager:
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(outputs)
# outputs = WorkflowEntry.handle_special_values(outputs)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
@ -125,7 +125,7 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
@ -242,9 +242,9 @@ class WorkflowCycleManager:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
@ -289,7 +289,7 @@ class WorkflowCycleManager:
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
@ -326,7 +326,7 @@ class WorkflowCycleManager:
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
outputs = event.outputs
# Convert metadata keys to strings
origin_metadata = {

@ -190,6 +190,13 @@ class WorkflowEntry:
# run node
generator = node_instance.run()
except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
@ -292,6 +299,12 @@ class WorkflowEntry:
return node_instance, generator
except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@staticmethod

@ -0,0 +1,49 @@
import json
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel
from core.file.models import File
from core.variables import Segment
class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
def default(self, o: Any):
if isinstance(o, Segment):
return o.value
elif isinstance(o, File):
return o.to_dict()
elif isinstance(o, BaseModel):
return o.model_dump(mode="json")
else:
return super().default(o)
class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
def _to_json_encodable_recursive(self, value: Any) -> Any:
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
return value
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
return res_list
return value

@ -422,32 +422,33 @@ class StorageKeyLoader:
upload_file_ids: list[uuid.UUID] = []
tool_file_ids: list[uuid.UUID] = []
for file in files:
if file.id is None:
related_model_id = file.related_id
if file.related_id is None:
raise ValueError("file id should not be None.")
if file.tenant_id != self._tenant_id:
err_msg = (
f"invalid file, expected tenant_id={self._tenant_id}, "
f"got tenant_id={file.tenant_id}, file_id={file.id}"
f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}"
)
raise ValueError(err_msg)
file_id = uuid.UUID(file.id)
model_id = uuid.UUID(related_model_id)
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_ids.append(file_id)
upload_file_ids.append(model_id)
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_ids.append(file_id)
tool_file_ids.append(model_id)
tool_files = self._load_tool_files(tool_file_ids)
upload_files = self._load_upload_files(upload_file_ids)
for file in files:
file_id = uuid.UUID(file.id)
model_id = uuid.UUID(file.related_id)
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_row = upload_files.get(file_id)
upload_file_row = upload_files.get(model_id)
if upload_file_row is None:
raise ValueError(...)
file._storage_key = upload_file_row.key
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_row = tool_files.get(file_id)
tool_file_row = tool_files.get(model_id)
if tool_file_row is None:
raise ValueError(...)
file._storage_key = tool_file_row.file_key

@ -662,8 +662,10 @@ class DraftVariableSaver:
self._node_type,
)
continue
value_seg = _build_segment_for_serialized_values(value)
if isinstance(value, Segment):
value_seg = value
else:
value_seg = _build_segment_for_serialized_values(value)
draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,

@ -460,7 +460,7 @@ class WorkflowService:
node_run_result = event.run_result
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
# node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break
if not node_run_result:
@ -522,7 +522,7 @@ class WorkflowService:
if node_run_result.process_data
else None
)
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
outputs = node_run_result.outputs
node_execution.inputs = inputs
node_execution.process_data = process_data

@ -101,26 +101,29 @@ class TestStorageKeyLoader(unittest.TestCase):
return tool_file
def _create_file(self, file_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None) -> File:
def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
related_id = None
file_related_id = None
remote_url = None
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
related_id = file_id
file_related_id = related_id
elif transfer_method == FileTransferMethod.REMOTE_URL:
remote_url = "https://example.com/test_file.txt"
file_related_id = related_id
return File(
id=file_id,
id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=related_id,
related_id=file_related_id,
remote_url=remote_url,
filename="test_file.txt",
extension=".txt",
@ -133,7 +136,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test loading storage keys for LOCAL_FILE transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
@ -145,7 +148,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test loading storage keys for REMOTE_URL transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.REMOTE_URL)
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
# Load storage keys
self.loader.load_storage_keys([file])
@ -157,7 +160,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test loading storage keys for TOOL_FILE transfer method."""
# Create test data
tool_file = self._create_tool_file()
file = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE)
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
@ -172,9 +175,9 @@ class TestStorageKeyLoader(unittest.TestCase):
upload_file2 = self._create_upload_file()
tool_file = self._create_tool_file()
file1 = self._create_file(upload_file1.id, FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(upload_file2.id, FileTransferMethod.REMOTE_URL)
file3 = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE)
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
files = [file1, file2, file3]
@ -195,7 +198,9 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test tenant_id validation."""
# Create file with different tenant_id
upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()))
file = self._create_file(
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
)
# Should raise ValueError for tenant mismatch
with pytest.raises(ValueError) as context:
@ -204,12 +209,12 @@ class TestStorageKeyLoader(unittest.TestCase):
assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_missing_file_id(self):
"""Test with None file.id."""
# Create a file with valid parameters first, then manually set id to None
file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE)
file.id = None
"""Test with None file.related_id."""
# Create a file with valid parameters first, then manually set related_id to None
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = None
# Should raise ValueError for None file id
# Should raise ValueError for None file related_id
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
@ -219,7 +224,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test with missing UploadFile database records."""
# Create file with non-existent upload file id
non_existent_id = str(uuid4())
file = self._create_file(non_existent_id, FileTransferMethod.LOCAL_FILE)
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
@ -229,7 +234,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test with missing ToolFile database records."""
# Create file with non-existent tool file id
non_existent_id = str(uuid4())
file = self._create_file(non_existent_id, FileTransferMethod.TOOL_FILE)
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
@ -237,9 +242,9 @@ class TestStorageKeyLoader(unittest.TestCase):
def test_load_storage_keys_invalid_uuid(self):
"""Test with invalid UUID format."""
# Create a file with valid parameters first, then manually set invalid id
file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE)
file.id = "invalid-uuid-format"
# Create a file with valid parameters first, then manually set invalid related_id
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = "invalid-uuid-format"
# Should raise ValueError for invalid UUID
with pytest.raises(ValueError):
@ -252,8 +257,12 @@ class TestStorageKeyLoader(unittest.TestCase):
tool_files = [self._create_tool_file() for _ in range(2)]
files = []
files.extend([self._create_file(uf.id, FileTransferMethod.LOCAL_FILE) for uf in upload_files])
files.extend([self._create_file(tf.id, FileTransferMethod.TOOL_FILE) for tf in tool_files])
files.extend(
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
)
files.extend(
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
)
# Mock the session to count queries
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
@ -275,7 +284,9 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE)
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
@ -296,7 +307,9 @@ class TestStorageKeyLoader(unittest.TestCase):
self.session.flush()
# Create file for other tenant but try to load with current tenant's loader
file_other = self._create_file(upload_file_other.id, FileTransferMethod.LOCAL_FILE, other_tenant_id)
file_other = self._create_file(
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError due to tenant mismatch
with pytest.raises(ValueError) as context:
@ -312,11 +325,15 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test batch with mixed tenant files (should fail on first mismatch)."""
# Create files for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE)
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create file for different tenant
other_tenant_id = str(uuid4())
file_other = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE, other_tenant_id)
file_other = self._create_file(
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError on tenant mismatch
with pytest.raises(ValueError) as context:
@ -329,9 +346,9 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file
upload_file = self._create_upload_file()
# Create two File objects with same ID
file1 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)
# Create two File objects with same related_id
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should handle duplicates gracefully
self.loader.load_storage_keys([file1, file2])
@ -344,7 +361,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test that the loader uses the provided session correctly."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Create loader with different session (same underlying connection)

Loading…
Cancel
Save