diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 85ff4f9c05..1ef024f46b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -58,21 +58,26 @@ def test_execute_answer(): pool.add(["start", "weather"], "sunny") pool.add(["llm", "text"], "You are a helpful AI.") + node_config = { + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + } + node = AnswerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 33f9251a72..71b3a8f7d8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch): ), ), ) + + node_config = { + "id": "1", + "data": data.model_dump(), + } + node = HttpRequestNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch): start_at=0, ), ) + + # Initialize node data + node.init_node_data(node_config["data"]) monkeypatch.setattr( "core.workflow.nodes.http_request.executor.file_manager.download", lambda *args, **kwargs: b"test", @@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch): ), ), ) + + node_config = { + "id": "1", + "data": data.model_dump(), + } + node = HttpRequestNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch): start_at=0, ), ) + + # Initialize node data + node.init_node_data(node_config["data"]) + monkeypatch.setattr( "core.workflow.nodes.http_request.executor.file_manager.download", lambda *args, **kwargs: b"test", @@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch): ), ) + node_config = { + "id": "1", + "data": data.model_dump(), + } + node = HttpRequestNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch): ), ) + # Initialize node data + node.init_node_data(node_config["data"]) + monkeypatch.setattr( "core.workflow.nodes.http_request.executor.file_manager.download", lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data", diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 17c23b7735..787d4cb3ee 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -162,25 +162,30 @@ def test_run(): ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + } + iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "tt", - "title": "迭代", - "type": "iteration", - }, - "id": "iteration-1", - }, + config=node_config, ) + # Initialize node data + iteration_node.init_node_data(node_config["data"]) + def tt_generator(self): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -379,25 +384,30 @@ def test_run_parallel(): ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + } + iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - }, - "id": "iteration-1", - }, + config=node_config, ) + # Initialize node data + iteration_node.init_node_data(node_config["data"]) + def tt_generator(self): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode(): ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + parallel_node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + } + parallel_iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - "is_parallel": True, - }, - "id": "iteration-1", - }, + config=parallel_node_config, ) + + # Initialize node data + parallel_iteration_node.init_node_data(parallel_node_config["data"]) + sequential_node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + } + sequential_iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - "is_parallel": True, - }, - "id": "iteration-1", - }, + config=sequential_node_config, ) + # Initialize node data + sequential_iteration_node.init_node_data(sequential_node_config["data"]) + def tt_generator(self): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -818,26 +838,31 @@ def test_iteration_run_error_handle(): environment_variables=[], ) pool.add(["pe", "list_output"], ["1", "1"]) + error_node_config = { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + "is_parallel": True, + "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, + }, + "id": "iteration-1", + } + iteration_node = IterationNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "iteration", - "type": "iteration", - "is_parallel": True, - "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, - }, - "id": "iteration-1", - }, + config=error_node_config, ) + + # Initialize node data + iteration_node.init_node_data(error_node_config["data"]) # execute continue on error node result = iteration_node._run() result_arr = [] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 42f3bffab1..23a7fab7cf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -119,17 +119,20 @@ def llm_node( llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState ) -> LLMNode: mock_file_saver = mock.MagicMock(spec=LLMFileSaver) + node_config = { + "id": "1", + "data": llm_node_data.model_dump(), + } node = LLMNode( id="1", - config={ - "id": "1", - "data": llm_node_data.model_dump(), - }, + config=node_config, graph_init_params=graph_init_params, graph=graph, graph_runtime_state=graph_runtime_state, llm_file_saver=mock_file_saver, ) + # Initialize node data + node.init_node_data(node_config["data"]) return node @@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node): variable_pool = llm_node.graph_runtime_state.variable_pool vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH - result = llm_node._handle_list_messages( + result = llm_node.handle_list_messages( messages=messages, context=context, jinja2_variables=jinja2_variables, @@ -506,17 +509,20 @@ def llm_node_for_multimodal( llm_node_data, graph_init_params, graph, graph_runtime_state ) -> tuple[LLMNode, LLMFileSaver]: mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) + node_config = { + "id": "1", + "data": llm_node_data.model_dump(), + } node = LLMNode( id="1", - config={ - "id": "1", - "data": llm_node_data.model_dump(), - }, + config=node_config, graph_init_params=graph_init_params, graph=graph, graph_runtime_state=graph_runtime_state, llm_file_saver=mock_file_saver, ) + # Initialize node data + node.init_node_data(node_config["data"]) return node, mock_file_saver @@ -544,6 +550,8 @@ class TestLLMNodeSaveMultiModalImageOutput: content=content, file_saver=mock_file_saver, ) + # Manually append to _file_outputs since the static method doesn't do it + llm_node._file_outputs.append(file) assert llm_node._file_outputs == [mock_file] assert file == mock_file mock_file_saver.save_binary_string.assert_called_once_with( @@ -573,6 +581,8 @@ class TestLLMNodeSaveMultiModalImageOutput: content=content, file_saver=mock_file_saver, ) + # Manually append to _file_outputs since the static method doesn't do it + llm_node._file_outputs.append(file) assert llm_node._file_outputs == [mock_file] assert file == mock_file mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) @@ -588,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode): class TestSaveMultimodalOutputAndConvertResultToMarkdown: def test_str_content(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents="hello world", file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == ["hello world"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() @@ -596,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: def test_text_prompt_message_content(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( - [TextPromptMessageContent(data="hello world")] + contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[] ) assert list(gen) == ["hello world"] mock_file_saver.save_binary_string.assert_not_called() @@ -622,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: ) mock_file_saver.save_binary_string.return_value = mock_saved_file gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( - [ + contents=[ ImagePromptMessageContent( format="png", base64_data=image_b64_data, mime_type="image/png", ) - ] + ], + file_saver=mock_file_saver, + file_outputs=llm_node._file_outputs, ) yielded_strs = list(gen) assert len(yielded_strs) == 1 @@ -651,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: def test_unknown_content_type(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == ["frozenset({'hello world'})"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() def test_unknown_item_type(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == ["frozenset({'hello world'})"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() def test_none_content(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal - gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + contents=None, file_saver=mock_file_saver, file_outputs=[] + ) assert list(gen) == [] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 44c31b212e..466d7bad06 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -61,21 +61,26 @@ def test_execute_answer(): variable_pool.add(["start", "weather"], "sunny") variable_pool.add(["llm", "text"], "You are a helpful AI.") + node_config = { + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + } + node = AnswerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 66c7818adf..486ae51e5f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -27,13 +27,17 @@ def document_extractor_node(): title="Test Document Extractor", variable_selector=["node_id", "variable_name"], ) - return DocumentExtractorNode( + node_config = {"id": "test_node_id", "data": node_data.model_dump()} + node = DocumentExtractorNode( id="test_node_id", - config={"id": "test_node_id", "data": node_data.model_dump()}, + config=node_config, graph_init_params=Mock(), graph=Mock(), graph_runtime_state=Mock(), ) + # Initialize node data + node.init_node_data(node_config["data"]) + return node @pytest.fixture diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 167a92484d..8383aee0e4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -57,57 +57,62 @@ def test_execute_if_else_result_true(): pool.add(["start", "null"], None) pool.add(["start", "not_null"], "1212") + node_config = { + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "and", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "not_contains"], + "value": "ab", + }, + {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, + {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, + {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, + {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, + {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, + {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, + {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, + {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, + {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, + {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, + { + "comparison_operator": "≥", + "variable_selector": ["start", "greater_than_or_equal"], + "value": "22", + }, + {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, + {"comparison_operator": "null", "variable_selector": ["start", "null"]}, + {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, + ], + }, + } + node = IfElseNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "id": "if-else", - "data": { - "title": "123", - "type": "if-else", - "logical_operator": "and", - "conditions": [ - { - "comparison_operator": "contains", - "variable_selector": ["start", "array_contains"], - "value": "ab", - }, - { - "comparison_operator": "not contains", - "variable_selector": ["start", "array_not_contains"], - "value": "ab", - }, - {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, - { - "comparison_operator": "not contains", - "variable_selector": ["start", "not_contains"], - "value": "ab", - }, - {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, - {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, - {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, - {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, - {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, - {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, - {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, - {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, - {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, - {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, - { - "comparison_operator": "≥", - "variable_selector": ["start", "greater_than_or_equal"], - "value": "22", - }, - {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, - {"comparison_operator": "null", "variable_selector": ["start", "null"]}, - {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() @@ -162,33 +167,38 @@ def test_execute_if_else_result_false(): pool.add(["start", "array_contains"], ["1ab", "def"]) pool.add(["start", "array_not_contains"], ["ab", "def"]) + node_config = { + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "or", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + ], + }, + } + node = IfElseNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), - config={ - "id": "if-else", - "data": { - "title": "123", - "type": "if-else", - "logical_operator": "or", - "conditions": [ - { - "comparison_operator": "contains", - "variable_selector": ["start", "array_contains"], - "value": "ab", - }, - { - "comparison_operator": "not contains", - "variable_selector": ["start", "array_not_contains"], - "value": "ab", - }, - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Mock db.session.close() db.session.close = MagicMock() @@ -228,17 +238,22 @@ def test_array_file_contains_file_name(): ], ) + node_config = { + "id": "if-else", + "data": node_data.model_dump(), + } + node = IfElseNode( id=str(uuid.uuid4()), graph_init_params=Mock(), graph=Mock(), graph_runtime_state=Mock(), - config={ - "id": "if-else", - "data": node_data.model_dump(), - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 7d3a1d6a2d..5fc9eab2df 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -33,16 +33,19 @@ def list_operator_node(): "title": "Test Title", } node_data = ListOperatorNodeData(**config) + node_config = { + "id": "test_node_id", + "data": node_data.model_dump(), + } node = ListOperatorNode( id="test_node_id", - config={ - "id": "test_node_id", - "data": node_data.model_dump(), - }, + config=node_config, graph_init_params=MagicMock(), graph=MagicMock(), graph_runtime_state=MagicMock(), ) + # Initialize node data + node.init_node_data(node_config["data"]) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 2776e57777..0eaabd0c40 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -38,12 +38,13 @@ def _create_tool_node(): system_variables=SystemVariable.empty(), user_inputs={}, ) + node_config = { + "id": "1", + "data": data.model_dump(), + } node = ToolNode( id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, + config=node_config, graph_init_params=GraphInitParams( tenant_id="1", app_id="1", @@ -71,6 +72,8 @@ def _create_tool_node(): start_at=0, ), ) + # Initialize node data + node.init_node_data(node_config["data"]) return node diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 62e3e37104..ee51339427 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -82,23 +82,28 @@ def test_overwrite_string_variable(): mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, + config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) + # Initialize node data + node.init_node_data(node_config["data"]) + list(node.run()) expected_var = StringVariable( id=conversation_variable.id, @@ -178,23 +183,28 @@ def test_append_variable_to_array(): mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND.value, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - }, + config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) + # Initialize node data + node.init_node_data(node_config["data"]) + list(node.run()) expected_value = list(conversation_variable.value) expected_value.append(input_variable.value) @@ -265,23 +275,28 @@ def test_clear_array(): mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.CLEAR.value, + "input_variable_selector": [], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR.value, - "input_variable_selector": [], - }, - }, + config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) + # Initialize node data + node.init_node_data(node_config["data"]) + list(node.run()) expected_var = ArrayStringVariable( id=conversation_variable.id, diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index a3a90b0599..987eaf7534 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -115,28 +115,33 @@ def test_remove_first_from_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_FIRST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment # Print the variable before running print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") @@ -202,28 +207,33 @@ def test_remove_last_from_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_LAST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment list(node.run()) @@ -281,28 +291,33 @@ def test_remove_first_from_empty_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_FIRST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment list(node.run()) @@ -360,28 +375,33 @@ def test_remove_last_from_empty_array(): conversation_variables=[conversation_variable], ) + node_config = { + "id": "node_id", + "data": { + "title": "test", + "version": "2", + "items": [ + { + "variable_selector": ["conversation", conversation_variable.name], + "input_type": InputType.VARIABLE, + "operation": Operation.REMOVE_LAST, + "value": None, + } + ], + }, + } + node = VariableAssignerNode( id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), - config={ - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - }, + config=node_config, ) + # Initialize node data + node.init_node_data(node_config["data"]) + # Skip the mock assertion since we're in a test environment list(node.run())