diff --git a/api/.ruff.toml b/api/.ruff.toml index facb0d5419..0169613bf8 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,6 +1,4 @@ -exclude = [ - "migrations/*", -] +exclude = ["migrations/*"] line-length = 120 [format] @@ -9,14 +7,14 @@ quote-style = "double" [lint] preview = false select = [ - "B", # flake8-bugbear rules - "C4", # flake8-comprehensions - "E", # pycodestyle E rules - "F", # pyflakes rules - "FURB", # refurb rules - "I", # isort rules - "N", # pep8-naming - "PT", # flake8-pytest-style rules + "B", # flake8-bugbear rules + "C4", # flake8-comprehensions + "E", # pycodestyle E rules + "F", # pyflakes rules + "FURB", # refurb rules + "I", # isort rules + "N", # pep8-naming + "PT", # flake8-pytest-style rules "PLC0208", # iteration-over-set "PLC0414", # useless-import-alias "PLE0604", # invalid-all-object @@ -24,19 +22,19 @@ select = [ "PLR0402", # manual-from-import "PLR1711", # useless-return "PLR1714", # repeated-equality-comparison - "RUF013", # implicit-optional - "RUF019", # unnecessary-key-check - "RUF100", # unused-noqa - "RUF101", # redirected-noqa - "RUF200", # invalid-pyproject-toml - "RUF022", # unsorted-dunder-all - "S506", # unsafe-yaml-load - "SIM", # flake8-simplify rules - "TRY400", # error-instead-of-exception - "TRY401", # verbose-log-message - "UP", # pyupgrade rules - "W191", # tab-indentation - "W605", # invalid-escape-sequence + "RUF013", # implicit-optional + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "RUF200", # invalid-pyproject-toml + "RUF022", # unsorted-dunder-all + "S506", # unsafe-yaml-load + "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception + "TRY401", # verbose-log-message + "UP", # pyupgrade rules + "W191", # tab-indentation + "W605", # invalid-escape-sequence # security related linting rules # RCE proctection (sort of) "S102", # exec-builtin, disallow use of `exec` @@ -47,36 +45,37 @@ select = [ ] ignore = [ - "E402", # module-import-not-at-top-of-file - "E711", # none-comparison - "E712", # true-false-comparison - "E721", # type-comparison - "E722", # bare-except - "F821", # undefined-name - "F841", # unused-variable + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "F821", # undefined-name + "F841", # unused-variable "FURB113", # repeated-append "FURB152", # math-constant - "UP007", # non-pep604-annotation - "UP032", # f-string - "UP045", # non-pep604-annotation-optional - "B005", # strip-with-multi-characters - "B006", # mutable-argument-default - "B007", # unused-loop-control-variable - "B026", # star-arg-unpacking-after-keyword-arg - "B903", # class-as-data-structure - "B904", # raise-without-from-inside-except - "B905", # zip-without-explicit-strict - "N806", # non-lowercase-variable-in-function - "N815", # mixed-case-variable-in-class-scope - "PT011", # pytest-raises-too-broad - "SIM102", # collapsible-if - "SIM103", # needless-bool - "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally - "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # enumerate-for-loop - "SIM117", # multiple-with-statements - "SIM210", # if-expr-with-true-false + "UP007", # non-pep604-annotation + "UP032", # f-string + "UP045", # non-pep604-annotation-optional + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B903", # class-as-data-structure + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ ] [lint.per-file-ignores] diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 6f524a5872..cd1d298ca2 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -125,7 +126,7 @@ class WorkflowResponseConverter: id=workflow_execution.id_, workflow_id=workflow_execution.workflow_id, status=workflow_execution.status, - outputs=workflow_execution.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs), error=workflow_execution.error_message, elapsed_time=workflow_execution.elapsed_time, total_tokens=workflow_execution.total_tokens, @@ -202,6 +203,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -214,7 +217,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -245,6 +248,8 @@ class WorkflowResponseConverter: if not workflow_node_execution.finished_at: return None + json_converter = WorkflowRuntimeTypeConverter() + return NodeRetryStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_execution_id, @@ -257,7 +262,7 @@ class WorkflowResponseConverter: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs, process_data=workflow_node_execution.process_data, - outputs=workflow_node_execution.outputs, + outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, @@ -376,6 +381,7 @@ class WorkflowResponseConverter: workflow_execution_id: str, event: QueueIterationCompletedEvent, ) -> IterationNodeCompletedStreamResponse: + json_converter = WorkflowRuntimeTypeConverter() return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_execution_id, @@ -384,7 +390,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=json_converter.to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, @@ -463,7 +469,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_data.title, - outputs=event.outputs, + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), created_at=int(time.time()), extras={}, inputs=event.inputs or {}, diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index e5ead9dc56..3c0ab12bde 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import ( WorkflowType, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -165,7 +166,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None - db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.outputs = ( + json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs)) + if domain_model.outputs + else None + ) db_model.status = domain_model.status db_model.error = domain_model.error_message if domain_model.error_message else None db_model.total_tokens = domain_model.total_tokens diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index f3a245e65b..797cce9354 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -19,7 +19,7 @@ from core.workflow.entities.workflow_node_execution import ( ) from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from libs.jsonutil import PydanticModelEncoder +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import ( Account, CreatorUserRole, @@ -147,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if not self._creator_user_role: raise ValueError("created_by_role is required in repository constructor") + json_converter = WorkflowRuntimeTypeConverter() db_model = WorkflowNodeExecutionModel() db_model.id = domain_model.id db_model.tenant_id = self._tenant_id @@ -161,11 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) db_model.node_id = domain_model.node_id db_model.node_type = domain_model.node_type db_model.title = domain_model.title - db_model.inputs = json.dumps(domain_model.inputs, cls=PydanticModelEncoder) if domain_model.inputs else None + db_model.inputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None + ) db_model.process_data = ( - json.dumps(domain_model.process_data, cls=PydanticModelEncoder) if domain_model.process_data else None + json.dumps(json_converter.to_json_encodable(domain_model.process_data)) + if domain_model.process_data + else None + ) + db_model.outputs = ( + json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None ) - db_model.outputs = json.dumps(domain_model.outputs, cls=PydanticModelEncoder) if domain_model.outputs else None db_model.status = domain_model.status db_model.error = domain_model.error db_model.elapsed_time = domain_model.elapsed_time diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 4387e9693e..68d3d82883 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -18,3 +18,17 @@ class SegmentType(StrEnum): NONE = "none" GROUP = "group" + + def is_array_type(self): + return self in _ARRAY_TYPES + + +_ARRAY_TYPES = frozenset( + [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] +) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index d7e36aa93e..38c2bcbdf5 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -49,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]): part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"answer": answer, "files": ArrayFileSegment(value=files)}, + ) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index fccb13360c..22ed9e2651 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -130,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]): prefix: str = "", depth: int = 1, ): + # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. + # Note that `_transform_result` may produce lists containing `None` values, + # which don't conform to the type requirements of `Array*Segment` classes. if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 4fc1ed8d9b..9f48b48865 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -24,7 +24,7 @@ from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment -from core.variables.segments import FileSegment +from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -71,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={"text": extracted_text_list}, + outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): extracted_text = _extract_text_from_file(value) diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index aa494fdb53..5059e1f191 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -6,6 +6,7 @@ from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager +from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -170,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): return mapping - def extract_files(self, url: str, response: Response) -> list[File]: + def extract_files(self, url: str, response: Response) -> ArrayFileSegment: """ Extract files from response by checking both Content-Type header and URL """ @@ -182,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): content_disposition_type = None if not is_file: - return files + return ArrayFileSegment(value=[]) if parsed_content_disposition: content_disposition_filename = parsed_content_disposition.get_filename() @@ -215,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): ) files.append(file) - return files + return ArrayFileSegment(value=files) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 5243a54bc4..151efc28ec 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -11,6 +11,7 @@ from flask import Flask, current_app from configs import dify_config from core.variables import ArrayVariable, IntegerVariable, NoneVariable +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import ( NodeRunResult, ) @@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from factories.variable_factory import build_segment from libs.flask_utils import preserve_flask_contexts from .exc import ( @@ -89,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") if isinstance(variable, NoneVariable) or len(variable.value) == 0: + # Try our best to preserve the type informat. + if isinstance(variable, ArraySegment): + output = variable.model_copy(update={"value": []}) + else: + output = ArrayAnySegment(value=[]) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": []}, + # TODO(QuantumGhost): is it possible to compute the type of `output` + # from graph definition? + outputs={"output": output}, ) ) return @@ -235,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]): # Flatten the list of lists if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): outputs = [item for sublist in outputs for item in sublist] + output_segment = build_segment(outputs) yield IterationRunSucceededEvent( iteration_id=self.id, @@ -251,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, + outputs={"output": output_segment}, metadata={ WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5cf5848d54..2995f0682f 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment +from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType @@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: results = self._fetch_dataset_retriever(node_data=node_data, query=query) - outputs = {"result": results} + outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs, # type: ignore ) except KnowledgeRetrievalNodeError as e: diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 56e8b20086..3c9ba44cf1 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -3,6 +3,7 @@ from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -34,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): if not variable.value: inputs = {"variable": []} process_data = {"variable": []} - outputs = {"result": [], "first_record": None, "last_record": None} + if isinstance(variable, ArraySegment): + result = variable.model_copy(update={"value": []}) + else: + result = ArrayAnySegment(value=[]) + outputs = {"result": result, "first_record": None, "last_record": None} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -75,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): variable = self._apply_slice(variable) outputs = { - "result": variable.value, + "result": variable, "first_record": variable.value[0] if variable.value else None, "last_record": variable.value[-1] if variable.value else None, } diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c4c9546920..124ae6d75d 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -259,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]): if structured_output: outputs["structured_output"] = structured_output if self._file_outputs is not None: - outputs["files"] = self._file_outputs + outputs["files"] = ArrayFileSegment(value=self._file_outputs) yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 369eb13b04..916778d167 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.llm import ModelConfig, VisionConfig +class _ParameterConfigError(Exception): + pass + + class ParameterConfig(BaseModel): """ Parameter Config. @@ -27,6 +31,19 @@ class ParameterConfig(BaseModel): raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return str(value) + def is_array_type(self) -> bool: + return self.type in ("array[string]", "array[number]", "array[object]") + + def element_type(self) -> Literal["string", "number", "object"]: + if self.type == "array[number]": + return "number" + elif self.type == "array[string]": + return "string" + elif self.type == "array[object]": + return "object" + else: + raise _ParameterConfigError(f"{self.type} is not array type.") + class ParameterExtractorNodeData(BaseNodeData): """ diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index bde319ebe2..8d6c2d0a5c 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser +from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( @@ -588,28 +590,30 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith("array"): + elif parameter.is_array_type(): if isinstance(result[parameter.name], list): - nested_type = parameter.type[6:-1] - transformed_result[parameter.name] = [] + nested_type = parameter.element_type() + assert nested_type is not None + segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) + transformed_result[parameter.name] = segment_value for item in result[parameter.name]: if nested_type == "number": if isinstance(item, int | float): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif isinstance(item, str): try: if "." in item: - transformed_result[parameter.name].append(float(item)) + segment_value.value.append(float(item)) else: - transformed_result[parameter.name].append(int(item)) + segment_value.value.append(int(item)) except ValueError: pass elif nested_type == "string": if isinstance(item, str): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) elif nested_type == "object": if isinstance(item, dict): - transformed_result[parameter.name].append(item) + segment_value.value.append(item) if parameter.name not in transformed_result: if parameter.type == "number": @@ -619,7 +623,9 @@ class ParameterExtractorNode(BaseNode): elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): - transformed_result[parameter.name] = [] + transformed_result[parameter.name] = build_segment_with_type( + segment_type=SegmentType(parameter.type), value=[] + ) return transformed_result diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 608ec5e004..aa15d69931 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayAnySegment +from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -304,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]): variables[variable_name] = variable_value elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None + assert isinstance(message.meta, File) files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) @@ -367,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": files, "json": json, **variables}, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 4b1f816cad..167805c6ff 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -22,7 +22,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): for selector in self.node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs = {"output": variable.to_object()} + outputs = {"output": variable} inputs = {".".join(selector[1:]): variable.to_object()} break @@ -32,7 +32,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs[group.group_name] = {"output": variable.to_object()} + outputs[group.group_name] = {"output": variable} inputs[".".join(selector[1:])] = variable.to_object() break diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index b88f9edd03..6ee562fc8d 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -92,7 +92,7 @@ class WorkflowCycleManager: ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(outputs) + # outputs = WorkflowEntry.handle_special_values(outputs) workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED workflow_execution.outputs = outputs or {} @@ -125,7 +125,7 @@ class WorkflowCycleManager: trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) + # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED execution.outputs = outputs or {} @@ -242,9 +242,9 @@ class WorkflowCycleManager: raise ValueError(f"Domain node execution not found: {event.node_execution_id}") # Process data - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + inputs = event.inputs + process_data = event.process_data + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -289,7 +289,7 @@ class WorkflowCycleManager: # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings execution_metadata_dict = {} @@ -326,7 +326,7 @@ class WorkflowCycleManager: finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = WorkflowEntry.handle_special_values(event.outputs) + outputs = event.outputs # Convert metadata keys to strings origin_metadata = { diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index ddf0620077..182c54fa77 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -190,6 +190,13 @@ class WorkflowEntry: # run node generator = node_instance.run() except Exception as e: + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + workflow.id, + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) return node_instance, generator @@ -292,6 +299,12 @@ class WorkflowEntry: return node_instance, generator except Exception as e: + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @staticmethod diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py new file mode 100644 index 0000000000..0123fdac18 --- /dev/null +++ b/api/core/workflow/workflow_type_encoder.py @@ -0,0 +1,49 @@ +import json +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel + +from core.file.models import File +from core.variables import Segment + + +class WorkflowRuntimeTypeEncoder(json.JSONEncoder): + def default(self, o: Any): + if isinstance(o, Segment): + return o.value + elif isinstance(o, File): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump(mode="json") + else: + return super().default(o) + + +class WorkflowRuntimeTypeConverter: + def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: + result = self._to_json_encodable_recursive(value) + return result if isinstance(result, Mapping) or result is None else dict(result) + + def _to_json_encodable_recursive(self, value: Any) -> Any: + if value is None: + return value + if isinstance(value, (bool, int, str, float)): + return value + if isinstance(value, Segment): + return self._to_json_encodable_recursive(value.value) + if isinstance(value, File): + return value.to_dict() + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = self._to_json_encodable_recursive(v) + return res + if isinstance(value, list): + res_list = [] + for item in value: + res_list.append(self._to_json_encodable_recursive(item)) + return res_list + return value diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 5de3e45ef7..e0beef40c6 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -422,32 +422,33 @@ class StorageKeyLoader: upload_file_ids: list[uuid.UUID] = [] tool_file_ids: list[uuid.UUID] = [] for file in files: - if file.id is None: + related_model_id = file.related_id + if file.related_id is None: raise ValueError("file id should not be None.") if file.tenant_id != self._tenant_id: err_msg = ( f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}" + f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" ) raise ValueError(err_msg) - file_id = uuid.UUID(file.id) + model_id = uuid.UUID(related_model_id) if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_ids.append(file_id) + upload_file_ids.append(model_id) elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_ids.append(file_id) + tool_file_ids.append(model_id) tool_files = self._load_tool_files(tool_file_ids) upload_files = self._load_upload_files(upload_file_ids) for file in files: - file_id = uuid.UUID(file.id) + model_id = uuid.UUID(file.related_id) if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_row = upload_files.get(file_id) + upload_file_row = upload_files.get(model_id) if upload_file_row is None: raise ValueError(...) file._storage_key = upload_file_row.key elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_row = tool_files.get(file_id) + tool_file_row = tool_files.get(model_id) if tool_file_row is None: raise ValueError(...) file._storage_key = tool_file_row.file_key diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 5d3011cbed..cd30440b4f 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -662,8 +662,10 @@ class DraftVariableSaver: self._node_type, ) continue - - value_seg = _build_segment_for_serialized_values(value) + if isinstance(value, Segment): + value_seg = value + else: + value_seg = _build_segment_for_serialized_values(value) draft_vars.append( WorkflowDraftVariable.new_node_variable( app_id=self._app_id, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 53a22c8e76..d52e4302ef 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -460,7 +460,7 @@ class WorkflowService: node_run_result = event.run_result # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) break if not node_run_result: @@ -522,7 +522,7 @@ class WorkflowService: if node_run_result.process_data else None ) - outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + outputs = node_run_result.outputs node_execution.inputs = inputs node_execution.process_data = process_data diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index e98d2be9ec..fecb3f6d95 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -101,26 +101,29 @@ class TestStorageKeyLoader(unittest.TestCase): return tool_file - def _create_file(self, file_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None) -> File: + def _create_file( + self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None + ) -> File: """Helper method to create a File object for testing.""" if tenant_id is None: tenant_id = self.tenant_id # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods - related_id = None + file_related_id = None remote_url = None if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): - related_id = file_id + file_related_id = related_id elif transfer_method == FileTransferMethod.REMOTE_URL: remote_url = "https://example.com/test_file.txt" + file_related_id = related_id return File( - id=file_id, + id=str(uuid4()), # Generate new UUID for File.id tenant_id=tenant_id, type=FileType.DOCUMENT, transfer_method=transfer_method, - related_id=related_id, + related_id=file_related_id, remote_url=remote_url, filename="test_file.txt", extension=".txt", @@ -133,7 +136,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test loading storage keys for LOCAL_FILE transfer method.""" # Create test data upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) # Load storage keys self.loader.load_storage_keys([file]) @@ -145,7 +148,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test loading storage keys for REMOTE_URL transfer method.""" # Create test data upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.REMOTE_URL) + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) # Load storage keys self.loader.load_storage_keys([file]) @@ -157,7 +160,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test loading storage keys for TOOL_FILE transfer method.""" # Create test data tool_file = self._create_tool_file() - file = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE) + file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) # Load storage keys self.loader.load_storage_keys([file]) @@ -172,9 +175,9 @@ class TestStorageKeyLoader(unittest.TestCase): upload_file2 = self._create_upload_file() tool_file = self._create_tool_file() - file1 = self._create_file(upload_file1.id, FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(upload_file2.id, FileTransferMethod.REMOTE_URL) - file3 = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE) + file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) files = [file1, file2, file3] @@ -195,7 +198,9 @@ class TestStorageKeyLoader(unittest.TestCase): """Test tenant_id validation.""" # Create file with different tenant_id upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())) + file = self._create_file( + related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + ) # Should raise ValueError for tenant mismatch with pytest.raises(ValueError) as context: @@ -204,12 +209,12 @@ class TestStorageKeyLoader(unittest.TestCase): assert "invalid file, expected tenant_id" in str(context.value) def test_load_storage_keys_missing_file_id(self): - """Test with None file.id.""" - # Create a file with valid parameters first, then manually set id to None - file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE) - file.id = None + """Test with None file.related_id.""" + # Create a file with valid parameters first, then manually set related_id to None + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = None - # Should raise ValueError for None file id + # Should raise ValueError for None file related_id with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file]) @@ -219,7 +224,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test with missing UploadFile database records.""" # Create file with non-existent upload file id non_existent_id = str(uuid4()) - file = self._create_file(non_existent_id, FileTransferMethod.LOCAL_FILE) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) # Should raise ValueError for missing record with pytest.raises(ValueError): @@ -229,7 +234,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test with missing ToolFile database records.""" # Create file with non-existent tool file id non_existent_id = str(uuid4()) - file = self._create_file(non_existent_id, FileTransferMethod.TOOL_FILE) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) # Should raise ValueError for missing record with pytest.raises(ValueError): @@ -237,9 +242,9 @@ class TestStorageKeyLoader(unittest.TestCase): def test_load_storage_keys_invalid_uuid(self): """Test with invalid UUID format.""" - # Create a file with valid parameters first, then manually set invalid id - file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE) - file.id = "invalid-uuid-format" + # Create a file with valid parameters first, then manually set invalid related_id + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = "invalid-uuid-format" # Should raise ValueError for invalid UUID with pytest.raises(ValueError): @@ -252,8 +257,12 @@ class TestStorageKeyLoader(unittest.TestCase): tool_files = [self._create_tool_file() for _ in range(2)] files = [] - files.extend([self._create_file(uf.id, FileTransferMethod.LOCAL_FILE) for uf in upload_files]) - files.extend([self._create_file(tf.id, FileTransferMethod.TOOL_FILE) for tf in tool_files]) + files.extend( + [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + ) + files.extend( + [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] + ) # Mock the session to count queries with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: @@ -275,7 +284,9 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file for current tenant upload_file_current = self._create_upload_file() - file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE) + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) # Create upload file for other tenant (but don't add to cleanup list) upload_file_other = UploadFile( @@ -296,7 +307,9 @@ class TestStorageKeyLoader(unittest.TestCase): self.session.flush() # Create file for other tenant but try to load with current tenant's loader - file_other = self._create_file(upload_file_other.id, FileTransferMethod.LOCAL_FILE, other_tenant_id) + file_other = self._create_file( + related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) # Should raise ValueError due to tenant mismatch with pytest.raises(ValueError) as context: @@ -312,11 +325,15 @@ class TestStorageKeyLoader(unittest.TestCase): """Test batch with mixed tenant files (should fail on first mismatch).""" # Create files for current tenant upload_file_current = self._create_upload_file() - file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE) + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) # Create file for different tenant other_tenant_id = str(uuid4()) - file_other = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE, other_tenant_id) + file_other = self._create_file( + related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) # Should raise ValueError on tenant mismatch with pytest.raises(ValueError) as context: @@ -329,9 +346,9 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file upload_file = self._create_upload_file() - # Create two File objects with same ID - file1 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) + # Create two File objects with same related_id + file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) # Should handle duplicates gracefully self.loader.load_storage_keys([file1, file2]) @@ -344,7 +361,7 @@ class TestStorageKeyLoader(unittest.TestCase): """Test that the loader uses the provided session correctly.""" # Create test data upload_file = self._create_upload_file() - file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE) + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) # Create loader with different session (same underlying connection)