From 3508afcd3254e0cdd4f341751234dda4f1bbe2c1 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Fri, 20 Jun 2025 15:25:32 +0800 Subject: [PATCH] feat(api): use `Array*Segment` for array types in `NodeRunResult.outputs` if possible. Previously, the `outputs` in `NodeRunResult` were primitive Python types such as `list` and `dict`. This brings challenges to type detection as the type information has been removed during the output persistence process. (especially for empty lists, since we cannot infer their type by peeking at their elements) This commit changes the outputs to `Array*Segment`, ensuring that the type of output arrays is preserved. --- api/.ruff.toml | 103 +++++++++--------- .../common/workflow_response_converter.py | 16 ++- ...qlalchemy_workflow_execution_repository.py | 7 +- ...hemy_workflow_node_execution_repository.py | 15 ++- api/core/variables/types.py | 14 +++ api/core/workflow/nodes/answer/answer_node.py | 5 +- api/core/workflow/nodes/code/code_node.py | 3 + .../workflow/nodes/document_extractor/node.py | 4 +- api/core/workflow/nodes/http_request/node.py | 7 +- .../nodes/iteration/iteration_node.py | 14 ++- .../knowledge_retrieval_node.py | 8 +- api/core/workflow/nodes/list_operator/node.py | 9 +- api/core/workflow/nodes/llm/node.py | 2 +- .../nodes/parameter_extractor/entities.py | 17 +++ .../parameter_extractor_node.py | 24 ++-- api/core/workflow/nodes/tool/tool_node.py | 5 +- .../variable_aggregator_node.py | 4 +- api/core/workflow/workflow_cycle_manager.py | 14 +-- api/core/workflow/workflow_entry.py | 13 +++ api/core/workflow/workflow_type_encoder.py | 49 +++++++++ api/factories/file_factory.py | 17 +-- .../workflow_draft_variable_service.py | 6 +- api/services/workflow_service.py | 4 +- .../factories/test_storage_key_loader.py | 81 ++++++++------ 24 files changed, 302 insertions(+), 139 deletions(-) create mode 100644 api/core/workflow/workflow_type_encoder.py diff --git a/api/.ruff.toml b/api/.ruff.toml index facb0d5419..0169613bf8 --- 
a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,6 +1,4 @@ -exclude = [ - "migrations/*", -] +exclude = ["migrations/*"] line-length = 120 [format] @@ -9,14 +7,14 @@ quote-style = "double" [lint] preview = false select = [ - "B", # flake8-bugbear rules - "C4", # flake8-comprehensions - "E", # pycodestyle E rules - "F", # pyflakes rules - "FURB", # refurb rules - "I", # isort rules - "N", # pep8-naming - "PT", # flake8-pytest-style rules + "B", # flake8-bugbear rules + "C4", # flake8-comprehensions + "E", # pycodestyle E rules + "F", # pyflakes rules + "FURB", # refurb rules + "I", # isort rules + "N", # pep8-naming + "PT", # flake8-pytest-style rules "PLC0208", # iteration-over-set "PLC0414", # useless-import-alias "PLE0604", # invalid-all-object @@ -24,19 +22,19 @@ select = [ "PLR0402", # manual-from-import "PLR1711", # useless-return "PLR1714", # repeated-equality-comparison - "RUF013", # implicit-optional - "RUF019", # unnecessary-key-check - "RUF100", # unused-noqa - "RUF101", # redirected-noqa - "RUF200", # invalid-pyproject-toml - "RUF022", # unsorted-dunder-all - "S506", # unsafe-yaml-load - "SIM", # flake8-simplify rules - "TRY400", # error-instead-of-exception - "TRY401", # verbose-log-message - "UP", # pyupgrade rules - "W191", # tab-indentation - "W605", # invalid-escape-sequence + "RUF013", # implicit-optional + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "RUF200", # invalid-pyproject-toml + "RUF022", # unsorted-dunder-all + "S506", # unsafe-yaml-load + "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception + "TRY401", # verbose-log-message + "UP", # pyupgrade rules + "W191", # tab-indentation + "W605", # invalid-escape-sequence # security related linting rules # RCE proctection (sort of) "S102", # exec-builtin, disallow use of `exec` @@ -47,36 +45,37 @@ select = [ ] ignore = [ - "E402", # module-import-not-at-top-of-file - "E711", # none-comparison - "E712", # true-false-comparison - "E721", # 
type-comparison - "E722", # bare-except - "F821", # undefined-name - "F841", # unused-variable + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "F821", # undefined-name + "F841", # unused-variable "FURB113", # repeated-append "FURB152", # math-constant - "UP007", # non-pep604-annotation - "UP032", # f-string - "UP045", # non-pep604-annotation-optional - "B005", # strip-with-multi-characters - "B006", # mutable-argument-default - "B007", # unused-loop-control-variable - "B026", # star-arg-unpacking-after-keyword-arg - "B903", # class-as-data-structure - "B904", # raise-without-from-inside-except - "B905", # zip-without-explicit-strict - "N806", # non-lowercase-variable-in-function - "N815", # mixed-case-variable-in-class-scope - "PT011", # pytest-raises-too-broad - "SIM102", # collapsible-if - "SIM103", # needless-bool - "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally - "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # enumerate-for-loop - "SIM117", # multiple-with-statements - "SIM210", # if-expr-with-true-false + "UP007", # non-pep604-annotation + "UP032", # f-string + "UP045", # non-pep604-annotation-optional + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B903", # class-as-data-structure + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + 
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ ] [lint.per-file-ignores] diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 6f524a5872..cd1d298ca2 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -125,7 +126,7 @@ class WorkflowResponseConverter: id=workflow_execution.id_, workflow_id=workflow_execution.workflow_id, status=workflow_execution.status, - outputs=workflow_execution.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs), error=workflow_execution.error_message, elapsed_time=workflow_execution.elapsed_time, total_tokens=workflow_execution.total_tokens, @@ -202,6 +203,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -214,7 +217,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ 
-245,6 +248,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeRetryStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -257,7 +262,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -376,6 +381,7 @@ class WorkflowResponseConverter: workflow_execution_id: str, event: QueueIterationCompletedEvent, ) -> IterationNodeCompletedStreamResponse: + json_converter = WorkflowRuntimeTypeConverter() return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_execution_id, @@ -384,7 +390,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=json_converter.to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, @@ -463,7 +469,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index e5ead9dc56..3c0ab12bde 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -16,6 +16,7 @@ from 
core.workflow.entities.workflow_execution import ( WorkflowType, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -165,7 +166,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.outputs = ( + json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs)) + if domain_model.outputs + else None + ) db_model.status = domain_model.status db_model.error = domain_model.error_message if domain_model.error_message else None db_model.total_tokens = domain_model.total_tokens diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index f3a245e65b..797cce9354 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -19,7 +19,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from libs.jsonutil import PydanticModelEncoder +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -147,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if not self._creator_user_role: raise ValueError("created_by_role is required in repository constructor") + json_converter = 
WorkflowRuntimeTypeConverter() db_model = WorkflowNodeExecutionModel() db_model.id = domain_model.id db_model.tenant_id = self._tenant_id @@ -161,11 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) db_model.node_id = domain_model.node_id db_model.node_type = domain_model.node_type db_model.title = domain_model.title - db_model.inputs = json.dumps(domain_model.inputs, cls=PydanticModelEncoder) if domain_model.inputs else None + db_model.inputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None + ) db_model.process_data = ( - json.dumps(domain_model.process_data, cls=PydanticModelEncoder) if domain_model.process_data else None + json.dumps(json_converter.to_json_encodable(domain_model.process_data)) + if domain_model.process_data + else None + ) + db_model.outputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None ) - db_model.outputs = json.dumps(domain_model.outputs, cls=PydanticModelEncoder) if domain_model.outputs else None db_model.status = domain_model.status db_model.error = domain_model.error db_model.elapsed_time = domain_model.elapsed_time diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 4387e9693e..68d3d82883 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -18,3 +18,17 @@ class SegmentType(StrEnum): NONE = "none" GROUP = "group" + + def is_array_type(self): + return self in _ARRAY_TYPES + + +_ARRAY_TYPES = frozenset( + [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] +) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index d7e36aa93e..38c2bcbdf5 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -49,7 +49,10 @@ class 
AnswerNode(BaseNode[AnswerNodeData]): part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"answer": answer, "files": ArrayFileSegment(value=files)}, + ) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index fccb13360c..22ed9e2651 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -130,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]): prefix: str = "", depth: int = 1, ): + # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. + # Note that `_transform_result` may produce lists containing `None` values, + # which don't conform to the type requirements of `Array*Segment` classes. if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 4fc1ed8d9b..9f48b48865 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -24,7 +24,7 @@ from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment -from core.variables.segments import FileSegment +from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -71,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): 
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={"text": extracted_text_list}, + outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): extracted_text = _extract_text_from_file(value) diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index aa494fdb53..5059e1f191 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -6,6 +6,7 @@ from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager +from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -170,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): return mapping - def extract_files(self, url: str, response: Response) -> list[File]: + def extract_files(self, url: str, response: Response) -> ArrayFileSegment: """ Extract files from response by checking both Content-Type header and URL """ @@ -182,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): content_disposition_type = None if not is_file: - return files + return ArrayFileSegment(value=[]) if parsed_content_disposition: content_disposition_filename = parsed_content_disposition.get_filename() @@ -215,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): ) files.append(file) - return files + return ArrayFileSegment(value=files) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 5243a54bc4..151efc28ec 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ 
-11,6 +11,7 @@ from flask import Flask, current_app from configs import dify_config from core.variables import ArrayVariable, IntegerVariable, NoneVariable +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import ( NodeRunResult, ) @@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from factories.variable_factory import build_segment from libs.flask_utils import preserve_flask_contexts from .exc import ( @@ -89,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") if isinstance(variable, NoneVariable) or len(variable.value) == 0: + # Try our best to preserve the type informat. + if isinstance(variable, ArraySegment): + output = variable.model_copy(update={"value": []}) + else: + output = ArrayAnySegment(value=[]) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": []}, + # TODO(QuantumGhost): is it possible to compute the type of `output` + # from graph definition? 
+ outputs={"output": output}, ) ) return @@ -235,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]): # Flatten the list of lists if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): outputs = [item for sublist in outputs for item in sublist] + output_segment = build_segment(outputs) yield IterationRunSucceededEvent( iteration_id=self.id, @@ -251,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, + outputs={"output": output_segment}, metadata={ WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5cf5848d54..2995f0682f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment +from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType @@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: results = self._fetch_dataset_retriever(node_data=node_data, query=query) - outputs = {"result": results} + outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, 
inputs=variables, process_data=None, outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs, # type: ignore ) except KnowledgeRetrievalNodeError as e: diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 56e8b20086..3c9ba44cf1 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -3,6 +3,7 @@ from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -34,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): if not variable.value: inputs = {"variable": []} process_data = {"variable": []} - outputs = {"result": [], "first_record": None, "last_record": None} + if isinstance(variable, ArraySegment): + result = variable.model_copy(update={"value": []}) + else: + result = ArrayAnySegment(value=[]) + outputs = {"result": result, "first_record": None, "last_record": None} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -75,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): variable = self._apply_slice(variable) outputs = { - "result": variable.value, + "result": variable, "first_record": variable.value[0] if variable.value else None, "last_record": variable.value[-1] if variable.value else None, } diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c4c9546920..124ae6d75d 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -259,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]): if 
structured_output: outputs["structured_output"] = structured_output if self._file_outputs is not None: - outputs["files"] = self._file_outputs + outputs["files"] = ArrayFileSegment(value=self._file_outputs) yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 369eb13b04..916778d167 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.llm import ModelConfig, VisionConfig +class _ParameterConfigError(Exception): + pass + + class ParameterConfig(BaseModel): """ Parameter Config. @@ -27,6 +31,19 @@ class ParameterConfig(BaseModel): raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return str(value) + def is_array_type(self) -> bool: + return self.type in ("array[string]", "array[number]", "array[object]") + + def element_type(self) -> Literal["string", "number", "object"]: + if self.type == "array[number]": + return "number" + elif self.type == "array[string]": + return "string" + elif self.type == "array[object]": + return "object" + else: + raise _ParameterConfigError(f"{self.type} is not array type.") + class ParameterExtractorNodeData(BaseNodeData): """ diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index bde319ebe2..8d6c2d0a5c 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform 
import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser +from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( @@ -588,28 +590,30 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith("array"): + elif parameter.is_array_type(): if isinstance(result[parameter.name], list): - nested_type = parameter.type[6:-1] - transformed_result[parameter.name] = [] + nested_type = parameter.element_type() + assert nested_type is not None + segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) + transformed_result[parameter.name] = segment_value for item in result[parameter.name]: if nested_type == "number": if isinstance(item, int | float): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif isinstance(item, str): try: if "." 
in item: - transformed_result[parameter.name].append(float(item)) + segment_value.value.append(float(item)) else: - transformed_result[parameter.name].append(int(item)) + segment_value.value.append(int(item)) except ValueError: pass elif nested_type == "string": if isinstance(item, str): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif nested_type == "object": if isinstance(item, dict): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) if parameter.name not in transformed_result: if parameter.type == "number": @@ -619,7 +623,9 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): - transformed_result[parameter.name] = [] + transformed_result[parameter.name] = build_segment_with_type( + segment_type=SegmentType(parameter.type), value=[] + ) return transformed_result diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 608ec5e004..aa15d69931 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayAnySegment +from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -304,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]): variables[variable_name] = variable_value elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None + assert isinstance(message.meta, File) 
files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) @@ -367,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": files, "json": json, **variables}, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 4b1f816cad..167805c6ff 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -22,7 +22,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): for selector in self.node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs = {"output": variable.to_object()} + outputs = {"output": variable} inputs = {".".join(selector[1:]): variable.to_object()} break @@ -32,7 +32,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs[group.group_name] = {"output": variable.to_object()} + outputs[group.group_name] = {"output": variable} inputs[".".join(selector[1:])] = variable.to_object() break diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index b88f9edd03..6ee562fc8d 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -92,7 +92,7 @@ class WorkflowCycleManager: ) -> WorkflowExecution: workflow_execution = 
self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(outputs) + # outputs = WorkflowEntry.handle_special_values(outputs) workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED workflow_execution.outputs = outputs or {} @@ -125,7 +125,7 @@ class WorkflowCycleManager: trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) + # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED execution.outputs = outputs or {} @@ -242,9 +242,9 @@ class WorkflowCycleManager: raise ValueError(f"Domain node execution not found: {event.node_execution_id}") # Process data - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + inputs = event.inputs + process_data = event.process_data + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -289,7 +289,7 @@ class WorkflowCycleManager: # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -326,7 +326,7 @@ class WorkflowCycleManager: finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings origin_metadata = { diff --git a/api/core/workflow/workflow_entry.py 
b/api/core/workflow/workflow_entry.py index ddf0620077..182c54fa77 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -190,6 +190,13 @@ class WorkflowEntry: # run node generator = node_instance.run() except Exception as e: + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + workflow.id, + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) return node_instance, generator @@ -292,6 +299,12 @@ class WorkflowEntry: return node_instance, generator except Exception as e: + logger.exception( + "error while running node_instance, node_id=%s, type=%s, version=%s", + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @staticmethod diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py new file mode 100644 index 0000000000..0123fdac18 --- /dev/null +++ b/api/core/workflow/workflow_type_encoder.py @@ -0,0 +1,49 @@ +import json +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel + +from core.file.models import File +from core.variables import Segment + + +class WorkflowRuntimeTypeEncoder(json.JSONEncoder): + def default(self, o: Any): + if isinstance(o, Segment): + return o.value + elif isinstance(o, File): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump(mode="json") + else: + return super().default(o) + + +class WorkflowRuntimeTypeConverter: + def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: + result = self._to_json_encodable_recursive(value) + return result if isinstance(result, Mapping) or result is None else dict(result) + + def _to_json_encodable_recursive(self, value: Any) -> Any: + if value is None: +
return value + if isinstance(value, (bool, int, str, float)): + return value + if isinstance(value, Segment): + return self._to_json_encodable_recursive(value.value) + if isinstance(value, File): + return value.to_dict() + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = self._to_json_encodable_recursive(v) + return res + if isinstance(value, list): + res_list = [] + for item in value: + res_list.append(self._to_json_encodable_recursive(item)) + return res_list + return value diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 5de3e45ef7..e0beef40c6 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -422,32 +422,33 @@ class StorageKeyLoader: upload_file_ids: list[uuid.UUID] = [] tool_file_ids: list[uuid.UUID] = [] for file in files: - if file.id is None: + related_model_id = file.related_id + if file.related_id is None: raise ValueError("file id should not be None.") if file.tenant_id != self._tenant_id: err_msg = ( f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}" + f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" ) raise ValueError(err_msg) - file_id = uuid.UUID(file.id) + model_id = uuid.UUID(related_model_id) if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_ids.append(file_id) + upload_file_ids.append(model_id) elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_ids.append(file_id) + tool_file_ids.append(model_id) tool_files = self._load_tool_files(tool_file_ids) upload_files = self._load_upload_files(upload_file_ids) for file in files: - file_id = uuid.UUID(file.id) + model_id = uuid.UUID(file.related_id) if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_row = 
upload_files.get(file_id) + upload_file_row = upload_files.get(model_id) if upload_file_row is None: raise ValueError(...) file._storage_key = upload_file_row.key elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_row = tool_files.get(file_id) + tool_file_row = tool_files.get(model_id) if tool_file_row is None: raise ValueError(...) file._storage_key = tool_file_row.file_key diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 5d3011cbed..cd30440b4f 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -662,8 +662,10 @@ class DraftVariableSaver: self._node_type, ) continue - - value_seg = _build_segment_for_serialized_values(value) + if isinstance(value, Segment): + value_seg = value + else: + value_seg = _build_segment_for_serialized_values(value) draft_vars.append( WorkflowDraftVariable.new_node_variable( app_id=self._app_id, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 53a22c8e76..d52e4302ef 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -460,7 +460,7 @@ class WorkflowService: node_run_result = event.run_result # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) break if not node_run_result: @@ -522,7 +522,7 @@ class WorkflowService: if node_run_result.process_data else None ) - outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + outputs = node_run_result.outputs node_execution.inputs = inputs node_execution.process_data = process_data diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index e98d2be9ec..fecb3f6d95 100644 --- 
a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -101,26 +101,29 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file(self, file_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None) -> File: + def _create_file( + self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None + ) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods - related_id = None + file_related_id = None remote_url = None if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): - related_id = file_id + file_related_id = related_id elif transfer_method == FileTransferMethod.REMOTE_URL: remote_url = "https://example.com/test_file.txt" + file_related_id = related_id return File( - id=file_id, + id=str(uuid4()), # Generate new UUID for File.id tenant_id=tenant_id, type=FileType.DOCUMENT, transfer_method=transfer_method, - related_id=related_id, + related_id=file_related_id, remote_url=remote_url, filename="test_file.txt", extension=".txt", @@ -133,7 +136,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test loading storage keys for LOCAL_FILE transfer method.""" # Create test data upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) # Load storage keys self.loader.load_storage_keys([file]) @@ -145,7 +148,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test loading storage keys for REMOTE_URL transfer method.""" # Create test data upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.REMOTE_URL) + file = self._create_file(related_id=upload_file.id, 
transfer_method=FileTransferMethod.REMOTE_URL) # Load storage keys self.loader.load_storage_keys([file]) @@ -157,7 +160,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test loading storage keys for TOOL_FILE transfer method.""" # Create test data tool_file = self._create_tool_file() - file = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE) + file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) # Load storage keys self.loader.load_storage_keys([file]) @@ -172,9 +175,9 @@ class TestStorageKeyLoader(unittest.TestCase): upload_file2 = self._create_upload_file() tool_file = self._create_tool_file() - file1 = self._create_file(upload_file1.id, FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(upload_file2.id, FileTransferMethod.REMOTE_URL) - file3 = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE) + file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) files = [file1, file2, file3] @@ -195,7 +198,9 @@ class TestStorageKeyLoader(unittest.TestCase): """Test tenant_id validation.""" # Create file with different tenant_id upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())) + file = self._create_file( + related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + ) # Should raise ValueError for tenant mismatch with pytest.raises(ValueError) as context: @@ -204,12 +209,12 @@ class TestStorageKeyLoader(unittest.TestCase): assert "invalid file, expected tenant_id" in str(context.value) def test_load_storage_keys_missing_file_id(self): - """Test with None file.id.""" - # Create a file with valid parameters first, then manually 
set id to None - file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE) - file.id = None + """Test with None file.related_id.""" + # Create a file with valid parameters first, then manually set related_id to None + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = None - # Should raise ValueError for None file id + # Should raise ValueError for None file related_id with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file]) @@ -219,7 +224,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test with missing UploadFile database records.""" # Create file with non-existent upload file id non_existent_id = str(uuid4()) - file = self._create_file(non_existent_id, FileTransferMethod.LOCAL_FILE) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) # Should raise ValueError for missing record with pytest.raises(ValueError): @@ -229,7 +234,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test with missing ToolFile database records.""" # Create file with non-existent tool file id non_existent_id = str(uuid4()) - file = self._create_file(non_existent_id, FileTransferMethod.TOOL_FILE) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) # Should raise ValueError for missing record with pytest.raises(ValueError): @@ -237,9 +242,9 @@ class TestStorageKeyLoader(unittest.TestCase): def test_load_storage_keys_invalid_uuid(self): """Test with invalid UUID format.""" - # Create a file with valid parameters first, then manually set invalid id - file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE) - file.id = "invalid-uuid-format" + # Create a file with valid parameters first, then manually set invalid related_id + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = "invalid-uuid-format" # Should 
raise ValueError for invalid UUID with pytest.raises(ValueError): @@ -252,8 +257,12 @@ class TestStorageKeyLoader(unittest.TestCase): tool_files = [self._create_tool_file() for _ in range(2)] files = [] - files.extend([self._create_file(uf.id, FileTransferMethod.LOCAL_FILE) for uf in upload_files]) - files.extend([self._create_file(tf.id, FileTransferMethod.TOOL_FILE) for tf in tool_files]) + files.extend( + [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + ) + files.extend( + [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] + ) # Mock the session to count queries with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: @@ -275,7 +284,9 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file for current tenant upload_file_current = self._create_upload_file() - file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE) + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) # Create upload file for other tenant (but don't add to cleanup list) upload_file_other = UploadFile( @@ -296,7 +307,9 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.flush() # Create file for other tenant but try to load with current tenant's loader - file_other = self._create_file(upload_file_other.id, FileTransferMethod.LOCAL_FILE, other_tenant_id) + file_other = self._create_file( + related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) # Should raise ValueError due to tenant mismatch with pytest.raises(ValueError) as context: @@ -312,11 +325,15 @@ class TestStorageKeyLoader(unittest.TestCase): """Test batch with mixed tenant files (should fail on first mismatch).""" # Create files for current tenant upload_file_current = self._create_upload_file() - file_current = 
self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE) + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) # Create file for different tenant other_tenant_id = str(uuid4()) - file_other = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE, other_tenant_id) + file_other = self._create_file( + related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) # Should raise ValueError on tenant mismatch with pytest.raises(ValueError) as context: @@ -329,9 +346,9 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file upload_file = self._create_upload_file() - # Create two File objects with same ID - file1 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) + # Create two File objects with same related_id + file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) # Should handle duplicates gracefully self.loader.load_storage_keys([file1, file2]) @@ -344,7 +361,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test that the loader uses the provided session correctly.""" # Create test data upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) # Create loader with different session (same underlying connection)