feat(api): use `Array*Segment` for array types in `NodeRunResult.outputs` if possible.

Previously, the `outputs` in `NodeRunResult` were primitive Python types
such as `list` and `dict`. This brings challenges to type detection as
the type information has been removed during the output persistence
process. (especially for empty list, since we cannot infer its type by
peeking at its elements)

This commit change the outputs to `Array*Segment`, ensure that the type
of output arrays are perserved.
pull/20699/head
QuantumGhost 11 months ago
parent b46a56c272
commit 3508afcd32

@ -1,6 +1,4 @@
exclude = [ exclude = ["migrations/*"]
"migrations/*",
]
line-length = 120 line-length = 120
[format] [format]
@ -77,6 +75,7 @@ ignore = [
"SIM113", # enumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
] ]
[lint.per-file-ignores] [lint.per-file-ignores]

@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@ -125,7 +126,7 @@ class WorkflowResponseConverter:
id=workflow_execution.id_, id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id, workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status, status=workflow_execution.status,
outputs=workflow_execution.outputs, outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message, error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time, elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens, total_tokens=workflow_execution.total_tokens,
@ -202,6 +203,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at: if not workflow_node_execution.finished_at:
return None return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeFinishStreamResponse( return NodeFinishStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id, workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -214,7 +217,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id, predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs, inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data, process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs, outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status, status=workflow_node_execution.status,
error=workflow_node_execution.error, error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time, elapsed_time=workflow_node_execution.elapsed_time,
@ -245,6 +248,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at: if not workflow_node_execution.finished_at:
return None return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeRetryStreamResponse( return NodeRetryStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id, workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -257,7 +262,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id, predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs, inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data, process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs, outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status, status=workflow_node_execution.status,
error=workflow_node_execution.error, error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time, elapsed_time=workflow_node_execution.elapsed_time,
@ -376,6 +381,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str, workflow_execution_id: str,
event: QueueIterationCompletedEvent, event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse: ) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
return IterationNodeCompletedStreamResponse( return IterationNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_execution_id, workflow_run_id=workflow_execution_id,
@ -384,7 +390,7 @@ class WorkflowResponseConverter:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type.value,
title=event.node_data.title, title=event.node_data.title,
outputs=event.outputs, outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()), created_at=int(time.time()),
extras={}, extras={},
inputs=event.inputs or {}, inputs=event.inputs or {},
@ -463,7 +469,7 @@ class WorkflowResponseConverter:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type.value,
title=event.node_data.title, title=event.node_data.title,
outputs=event.outputs, outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()), created_at=int(time.time()),
extras={}, extras={},
inputs=event.inputs or {}, inputs=event.inputs or {},

@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import (
WorkflowType, WorkflowType,
) )
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@ -165,7 +166,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
db_model.version = domain_model.workflow_version db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None db_model.outputs = (
json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs))
if domain_model.outputs
else None
)
db_model.status = domain_model.status db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.total_tokens = domain_model.total_tokens db_model.total_tokens = domain_model.total_tokens

@ -19,7 +19,7 @@ from core.workflow.entities.workflow_node_execution import (
) )
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from libs.jsonutil import PydanticModelEncoder from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@ -147,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role: if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor") raise ValueError("created_by_role is required in repository constructor")
json_converter = WorkflowRuntimeTypeConverter()
db_model = WorkflowNodeExecutionModel() db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id db_model.tenant_id = self._tenant_id
@ -161,11 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.node_id = domain_model.node_id db_model.node_id = domain_model.node_id
db_model.node_type = domain_model.node_type db_model.node_type = domain_model.node_type
db_model.title = domain_model.title db_model.title = domain_model.title
db_model.inputs = json.dumps(domain_model.inputs, cls=PydanticModelEncoder) if domain_model.inputs else None db_model.inputs = (
json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None
)
db_model.process_data = ( db_model.process_data = (
json.dumps(domain_model.process_data, cls=PydanticModelEncoder) if domain_model.process_data else None json.dumps(json_converter.to_json_encodable(domain_model.process_data))
if domain_model.process_data
else None
)
db_model.outputs = (
json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None
) )
db_model.outputs = json.dumps(domain_model.outputs, cls=PydanticModelEncoder) if domain_model.outputs else None
db_model.status = domain_model.status db_model.status = domain_model.status
db_model.error = domain_model.error db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time db_model.elapsed_time = domain_model.elapsed_time

@ -18,3 +18,17 @@ class SegmentType(StrEnum):
NONE = "none" NONE = "none"
GROUP = "group" GROUP = "group"
def is_array_type(self):
return self in _ARRAY_TYPES
_ARRAY_TYPES = frozenset(
[
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
)

@ -49,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
part = cast(TextGenerateRouteChunk, part) part = cast(TextGenerateRouteChunk, part)
answer += part.text answer += part.text
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
)
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(

@ -130,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]):
prefix: str = "", prefix: str = "",
depth: int = 1, depth: int = 1,
): ):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH: if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")

@ -24,7 +24,7 @@ from configs import dify_config
from core.file import File, FileTransferMethod, file_manager from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -71,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,
outputs={"text": extracted_text_list}, outputs={"text": ArrayStringSegment(value=extracted_text_list)},
) )
elif isinstance(value, File): elif isinstance(value, File):
extracted_text = _extract_text_from_file(value) extracted_text = _extract_text_from_file(value)

@ -6,6 +6,7 @@ from typing import Any, Optional
from configs import dify_config from configs import dify_config
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -170,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
return mapping return mapping
def extract_files(self, url: str, response: Response) -> list[File]: def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
""" """
Extract files from response by checking both Content-Type header and URL Extract files from response by checking both Content-Type header and URL
""" """
@ -182,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content_disposition_type = None content_disposition_type = None
if not is_file: if not is_file:
return files return ArrayFileSegment(value=[])
if parsed_content_disposition: if parsed_content_disposition:
content_disposition_filename = parsed_content_disposition.get_filename() content_disposition_filename = parsed_content_disposition.get_filename()
@ -215,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
) )
files.append(file) files.append(file)
return files return ArrayFileSegment(value=files)

@ -11,6 +11,7 @@ from flask import Flask, current_app
from configs import dify_config from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import ( from core.workflow.entities.node_entities import (
NodeRunResult, NodeRunResult,
) )
@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from .exc import ( from .exc import (
@ -89,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0: if isinstance(variable, NoneVariable) or len(variable.value) == 0:
# Try our best to preserve the type informat.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": []}, # TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
) )
) )
return return
@ -235,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]):
# Flatten the list of lists # Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist] outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.id, iteration_id=self.id,
@ -251,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs}, outputs={"output": output_segment},
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,

@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment from core.variables import StringSegment
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge # retrieve knowledge
try: try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query) results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results} outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs, # type: ignore
) )
except KnowledgeRetrievalNodeError as e: except KnowledgeRetrievalNodeError as e:

@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from core.file import File from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -34,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
if not variable.value: if not variable.value:
inputs = {"variable": []} inputs = {"variable": []}
process_data = {"variable": []} process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None} if isinstance(variable, ArraySegment):
result = variable.model_copy(update={"value": []})
else:
result = ArrayAnySegment(value=[])
outputs = {"result": result, "first_record": None, "last_record": None}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
@ -75,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
variable = self._apply_slice(variable) variable = self._apply_slice(variable)
outputs = { outputs = {
"result": variable.value, "result": variable,
"first_record": variable.value[0] if variable.value else None, "first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None, "last_record": variable.value[-1] if variable.value else None,
} }

@ -259,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if structured_output: if structured_output:
outputs["structured_output"] = structured_output outputs["structured_output"] = structured_output
if self._file_outputs is not None: if self._file_outputs is not None:
outputs["files"] = self._file_outputs outputs["files"] = ArrayFileSegment(value=self._file_outputs)
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(

@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig from core.workflow.nodes.llm import ModelConfig, VisionConfig
class _ParameterConfigError(Exception):
pass
class ParameterConfig(BaseModel): class ParameterConfig(BaseModel):
""" """
Parameter Config. Parameter Config.
@ -27,6 +31,19 @@ class ParameterConfig(BaseModel):
raise ValueError("Invalid parameter name, __reason and __is_success are reserved") raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return str(value) return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
class ParameterExtractorNodeData(BaseNodeData): class ParameterExtractorNodeData(BaseNodeData):
""" """

@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData from .entities import ParameterExtractorNodeData
from .exc import ( from .exc import (
@ -588,28 +590,30 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}: elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str): if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name] transformed_result[parameter.name] = result[parameter.name]
elif parameter.type.startswith("array"): elif parameter.is_array_type():
if isinstance(result[parameter.name], list): if isinstance(result[parameter.name], list):
nested_type = parameter.type[6:-1] nested_type = parameter.element_type()
transformed_result[parameter.name] = [] assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]: for item in result[parameter.name]:
if nested_type == "number": if nested_type == "number":
if isinstance(item, int | float): if isinstance(item, int | float):
transformed_result[parameter.name].append(item) segment_value.value.append(item)
elif isinstance(item, str): elif isinstance(item, str):
try: try:
if "." in item: if "." in item:
transformed_result[parameter.name].append(float(item)) segment_value.value.append(float(item))
else: else:
transformed_result[parameter.name].append(int(item)) segment_value.value.append(int(item))
except ValueError: except ValueError:
pass pass
elif nested_type == "string": elif nested_type == "string":
if isinstance(item, str): if isinstance(item, str):
transformed_result[parameter.name].append(item) segment_value.value.append(item)
elif nested_type == "object": elif nested_type == "object":
if isinstance(item, dict): if isinstance(item, dict):
transformed_result[parameter.name].append(item) segment_value.value.append(item)
if parameter.name not in transformed_result: if parameter.name not in transformed_result:
if parameter.type == "number": if parameter.type == "number":
@ -619,7 +623,9 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}: elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = "" transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"): elif parameter.type.startswith("array"):
transformed_result[parameter.name] = [] transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
return transformed_result return transformed_result

@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -304,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]):
variables[variable_name] = variable_value variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE: elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None assert message.meta is not None
assert isinstance(message.meta, File)
files.append(message.meta["file"]) files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG: elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage) assert isinstance(message.message, ToolInvokeMessage.LogMessage)
@ -367,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]):
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables}, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
metadata={ metadata={
**agent_execution_metadata, **agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,

@ -22,7 +22,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
for selector in self.node_data.variables: for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None: if variable is not None:
outputs = {"output": variable.to_object()} outputs = {"output": variable}
inputs = {".".join(selector[1:]): variable.to_object()} inputs = {".".join(selector[1:]): variable.to_object()}
break break
@ -32,7 +32,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None: if variable is not None:
outputs[group.group_name] = {"output": variable.to_object()} outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable.to_object() inputs[".".join(selector[1:])] = variable.to_object()
break break

@ -92,7 +92,7 @@ class WorkflowCycleManager:
) -> WorkflowExecution: ) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(outputs) # outputs = WorkflowEntry.handle_special_values(outputs)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {} workflow_execution.outputs = outputs or {}
@ -125,7 +125,7 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution: ) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id) execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {} execution.outputs = outputs or {}
@ -242,9 +242,9 @@ class WorkflowCycleManager:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}") raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data # Process data
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = event.inputs
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = event.process_data
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = event.outputs
# Convert metadata keys to strings # Convert metadata keys to strings
execution_metadata_dict = {} execution_metadata_dict = {}
@ -289,7 +289,7 @@ class WorkflowCycleManager:
# Process data # Process data
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = event.outputs
# Convert metadata keys to strings # Convert metadata keys to strings
execution_metadata_dict = {} execution_metadata_dict = {}
@ -326,7 +326,7 @@ class WorkflowCycleManager:
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds() elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = event.outputs
# Convert metadata keys to strings # Convert metadata keys to strings
origin_metadata = { origin_metadata = {

@ -190,6 +190,13 @@ class WorkflowEntry:
# run node # run node
generator = node_instance.run() generator = node_instance.run()
except Exception as e: except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator return node_instance, generator
@ -292,6 +299,12 @@ class WorkflowEntry:
return node_instance, generator return node_instance, generator
except Exception as e: except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@staticmethod @staticmethod

@ -0,0 +1,49 @@
import json
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel
from core.file.models import File
from core.variables import Segment
class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
def default(self, o: Any):
if isinstance(o, Segment):
return o.value
elif isinstance(o, File):
return o.to_dict()
elif isinstance(o, BaseModel):
return o.model_dump(mode="json")
else:
return super().default(o)
class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
def _to_json_encodable_recursive(self, value: Any) -> Any:
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
return value
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
return res_list
return value

@ -422,32 +422,33 @@ class StorageKeyLoader:
upload_file_ids: list[uuid.UUID] = [] upload_file_ids: list[uuid.UUID] = []
tool_file_ids: list[uuid.UUID] = [] tool_file_ids: list[uuid.UUID] = []
for file in files: for file in files:
if file.id is None: related_model_id = file.related_id
if file.related_id is None:
raise ValueError("file id should not be None.") raise ValueError("file id should not be None.")
if file.tenant_id != self._tenant_id: if file.tenant_id != self._tenant_id:
err_msg = ( err_msg = (
f"invalid file, expected tenant_id={self._tenant_id}, " f"invalid file, expected tenant_id={self._tenant_id}, "
f"got tenant_id={file.tenant_id}, file_id={file.id}" f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}"
) )
raise ValueError(err_msg) raise ValueError(err_msg)
file_id = uuid.UUID(file.id) model_id = uuid.UUID(related_model_id)
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_ids.append(file_id) upload_file_ids.append(model_id)
elif file.transfer_method == FileTransferMethod.TOOL_FILE: elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_ids.append(file_id) tool_file_ids.append(model_id)
tool_files = self._load_tool_files(tool_file_ids) tool_files = self._load_tool_files(tool_file_ids)
upload_files = self._load_upload_files(upload_file_ids) upload_files = self._load_upload_files(upload_file_ids)
for file in files: for file in files:
file_id = uuid.UUID(file.id) model_id = uuid.UUID(file.related_id)
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_row = upload_files.get(file_id) upload_file_row = upload_files.get(model_id)
if upload_file_row is None: if upload_file_row is None:
raise ValueError(...) raise ValueError(...)
file._storage_key = upload_file_row.key file._storage_key = upload_file_row.key
elif file.transfer_method == FileTransferMethod.TOOL_FILE: elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_row = tool_files.get(file_id) tool_file_row = tool_files.get(model_id)
if tool_file_row is None: if tool_file_row is None:
raise ValueError(...) raise ValueError(...)
file._storage_key = tool_file_row.file_key file._storage_key = tool_file_row.file_key

@ -662,7 +662,9 @@ class DraftVariableSaver:
self._node_type, self._node_type,
) )
continue continue
if isinstance(value, Segment):
value_seg = value
else:
value_seg = _build_segment_for_serialized_values(value) value_seg = _build_segment_for_serialized_values(value)
draft_vars.append( draft_vars.append(
WorkflowDraftVariable.new_node_variable( WorkflowDraftVariable.new_node_variable(

@ -460,7 +460,7 @@ class WorkflowService:
node_run_result = event.run_result node_run_result = event.run_result
# sign output files # sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break break
if not node_run_result: if not node_run_result:
@ -522,7 +522,7 @@ class WorkflowService:
if node_run_result.process_data if node_run_result.process_data
else None else None
) )
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None outputs = node_run_result.outputs
node_execution.inputs = inputs node_execution.inputs = inputs
node_execution.process_data = process_data node_execution.process_data = process_data

@ -101,26 +101,29 @@ class TestStorageKeyLoader(unittest.TestCase):
return tool_file return tool_file
def _create_file(self, file_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None) -> File: def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
"""Helper method to create a File object for testing.""" """Helper method to create a File object for testing."""
if tenant_id is None: if tenant_id is None:
tenant_id = self.tenant_id tenant_id = self.tenant_id
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
related_id = None file_related_id = None
remote_url = None remote_url = None
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
related_id = file_id file_related_id = related_id
elif transfer_method == FileTransferMethod.REMOTE_URL: elif transfer_method == FileTransferMethod.REMOTE_URL:
remote_url = "https://example.com/test_file.txt" remote_url = "https://example.com/test_file.txt"
file_related_id = related_id
return File( return File(
id=file_id, id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id, tenant_id=tenant_id,
type=FileType.DOCUMENT, type=FileType.DOCUMENT,
transfer_method=transfer_method, transfer_method=transfer_method,
related_id=related_id, related_id=file_related_id,
remote_url=remote_url, remote_url=remote_url,
filename="test_file.txt", filename="test_file.txt",
extension=".txt", extension=".txt",
@ -133,7 +136,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test loading storage keys for LOCAL_FILE transfer method.""" """Test loading storage keys for LOCAL_FILE transfer method."""
# Create test data # Create test data
upload_file = self._create_upload_file() upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Load storage keys # Load storage keys
self.loader.load_storage_keys([file]) self.loader.load_storage_keys([file])
@ -145,7 +148,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test loading storage keys for REMOTE_URL transfer method.""" """Test loading storage keys for REMOTE_URL transfer method."""
# Create test data # Create test data
upload_file = self._create_upload_file() upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.REMOTE_URL) file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
# Load storage keys # Load storage keys
self.loader.load_storage_keys([file]) self.loader.load_storage_keys([file])
@ -157,7 +160,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test loading storage keys for TOOL_FILE transfer method.""" """Test loading storage keys for TOOL_FILE transfer method."""
# Create test data # Create test data
tool_file = self._create_tool_file() tool_file = self._create_tool_file()
file = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE) file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
# Load storage keys # Load storage keys
self.loader.load_storage_keys([file]) self.loader.load_storage_keys([file])
@ -172,9 +175,9 @@ class TestStorageKeyLoader(unittest.TestCase):
upload_file2 = self._create_upload_file() upload_file2 = self._create_upload_file()
tool_file = self._create_tool_file() tool_file = self._create_tool_file()
file1 = self._create_file(upload_file1.id, FileTransferMethod.LOCAL_FILE) file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(upload_file2.id, FileTransferMethod.REMOTE_URL) file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE) file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
files = [file1, file2, file3] files = [file1, file2, file3]
@ -195,7 +198,9 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test tenant_id validation.""" """Test tenant_id validation."""
# Create file with different tenant_id # Create file with different tenant_id
upload_file = self._create_upload_file() upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())) file = self._create_file(
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
)
# Should raise ValueError for tenant mismatch # Should raise ValueError for tenant mismatch
with pytest.raises(ValueError) as context: with pytest.raises(ValueError) as context:
@ -204,12 +209,12 @@ class TestStorageKeyLoader(unittest.TestCase):
assert "invalid file, expected tenant_id" in str(context.value) assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_missing_file_id(self): def test_load_storage_keys_missing_file_id(self):
"""Test with None file.id.""" """Test with None file.related_id."""
# Create a file with valid parameters first, then manually set id to None # Create a file with valid parameters first, then manually set related_id to None
file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE) file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.id = None file.related_id = None
# Should raise ValueError for None file id # Should raise ValueError for None file related_id
with pytest.raises(ValueError) as context: with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file]) self.loader.load_storage_keys([file])
@ -219,7 +224,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test with missing UploadFile database records.""" """Test with missing UploadFile database records."""
# Create file with non-existent upload file id # Create file with non-existent upload file id
non_existent_id = str(uuid4()) non_existent_id = str(uuid4())
file = self._create_file(non_existent_id, FileTransferMethod.LOCAL_FILE) file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should raise ValueError for missing record # Should raise ValueError for missing record
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -229,7 +234,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test with missing ToolFile database records.""" """Test with missing ToolFile database records."""
# Create file with non-existent tool file id # Create file with non-existent tool file id
non_existent_id = str(uuid4()) non_existent_id = str(uuid4())
file = self._create_file(non_existent_id, FileTransferMethod.TOOL_FILE) file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
# Should raise ValueError for missing record # Should raise ValueError for missing record
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -237,9 +242,9 @@ class TestStorageKeyLoader(unittest.TestCase):
def test_load_storage_keys_invalid_uuid(self): def test_load_storage_keys_invalid_uuid(self):
"""Test with invalid UUID format.""" """Test with invalid UUID format."""
# Create a file with valid parameters first, then manually set invalid id # Create a file with valid parameters first, then manually set invalid related_id
file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE) file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.id = "invalid-uuid-format" file.related_id = "invalid-uuid-format"
# Should raise ValueError for invalid UUID # Should raise ValueError for invalid UUID
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -252,8 +257,12 @@ class TestStorageKeyLoader(unittest.TestCase):
tool_files = [self._create_tool_file() for _ in range(2)] tool_files = [self._create_tool_file() for _ in range(2)]
files = [] files = []
files.extend([self._create_file(uf.id, FileTransferMethod.LOCAL_FILE) for uf in upload_files]) files.extend(
files.extend([self._create_file(tf.id, FileTransferMethod.TOOL_FILE) for tf in tool_files]) [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
)
files.extend(
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
)
# Mock the session to count queries # Mock the session to count queries
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
@ -275,7 +284,9 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file for current tenant # Create upload file for current tenant
upload_file_current = self._create_upload_file() upload_file_current = self._create_upload_file()
file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE) file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create upload file for other tenant (but don't add to cleanup list) # Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile( upload_file_other = UploadFile(
@ -296,7 +307,9 @@ class TestStorageKeyLoader(unittest.TestCase):
self.session.flush() self.session.flush()
# Create file for other tenant but try to load with current tenant's loader # Create file for other tenant but try to load with current tenant's loader
file_other = self._create_file(upload_file_other.id, FileTransferMethod.LOCAL_FILE, other_tenant_id) file_other = self._create_file(
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError due to tenant mismatch # Should raise ValueError due to tenant mismatch
with pytest.raises(ValueError) as context: with pytest.raises(ValueError) as context:
@ -312,11 +325,15 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test batch with mixed tenant files (should fail on first mismatch).""" """Test batch with mixed tenant files (should fail on first mismatch)."""
# Create files for current tenant # Create files for current tenant
upload_file_current = self._create_upload_file() upload_file_current = self._create_upload_file()
file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE) file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create file for different tenant # Create file for different tenant
other_tenant_id = str(uuid4()) other_tenant_id = str(uuid4())
file_other = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE, other_tenant_id) file_other = self._create_file(
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError on tenant mismatch # Should raise ValueError on tenant mismatch
with pytest.raises(ValueError) as context: with pytest.raises(ValueError) as context:
@ -329,9 +346,9 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file # Create upload file
upload_file = self._create_upload_file() upload_file = self._create_upload_file()
# Create two File objects with same ID # Create two File objects with same related_id
file1 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should handle duplicates gracefully # Should handle duplicates gracefully
self.loader.load_storage_keys([file1, file2]) self.loader.load_storage_keys([file1, file2])
@ -344,7 +361,7 @@ class TestStorageKeyLoader(unittest.TestCase):
"""Test that the loader uses the provided session correctly.""" """Test that the loader uses the provided session correctly."""
# Create test data # Create test data
upload_file = self._create_upload_file() upload_file = self._create_upload_file()
file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Create loader with different session (same underlying connection) # Create loader with different session (same underlying connection)

Loading…
Cancel
Save