test(nodes): Fix nodes tests

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent cbc501875a
commit b538b4f79a
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -58,21 +58,26 @@ def test_execute_answer():
pool.add(["start", "weather"], "sunny") pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.") pool.add(["llm", "text"], "You are a helpful AI.")
node = AnswerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "answer", "id": "answer",
"data": { "data": {
"title": "123", "title": "123",
"type": "answer", "type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
}, },
}, }
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()

@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
), ),
), ),
) )
node = HttpRequestNode(
id="1", node_config = {
config={
"id": "1", "id": "1",
"data": data.model_dump(), "data": data.model_dump(),
}, }
node = HttpRequestNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0, start_at=0,
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr( monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download", "core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test", lambda *args, **kwargs: b"test",
@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
), ),
), ),
) )
node = HttpRequestNode(
id="1", node_config = {
config={
"id": "1", "id": "1",
"data": data.model_dump(), "data": data.model_dump(),
}, }
node = HttpRequestNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0, start_at=0,
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr( monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download", "core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test", lambda *args, **kwargs: b"test",
@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
), ),
) )
node = HttpRequestNode( node_config = {
id="1",
config={
"id": "1", "id": "1",
"data": data.model_dump(), "data": data.model_dump(),
}, }
node = HttpRequestNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr( monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download", "core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data", lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",

@ -162,12 +162,7 @@ def test_run():
) )
pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
iteration_node = IterationNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": { "data": {
"iterator_selector": ["pe", "list_output"], "iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"], "output_selector": ["tt", "output"],
@ -178,9 +173,19 @@ def test_run():
"type": "iteration", "type": "iteration",
}, },
"id": "iteration-1", "id": "iteration-1",
}, }
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self): def tt_generator(self):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -379,12 +384,7 @@ def test_run_parallel():
) )
pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
iteration_node = IterationNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": { "data": {
"iterator_selector": ["pe", "list_output"], "iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"], "output_selector": ["tt", "output"],
@ -395,9 +395,19 @@ def test_run_parallel():
"type": "iteration", "type": "iteration",
}, },
"id": "iteration-1", "id": "iteration-1",
}, }
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self): def tt_generator(self):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -595,12 +605,7 @@ def test_iteration_run_in_parallel_mode():
) )
pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
parallel_iteration_node = IterationNode( parallel_node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": { "data": {
"iterator_selector": ["pe", "list_output"], "iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"], "output_selector": ["tt", "output"],
@ -612,14 +617,19 @@ def test_iteration_run_in_parallel_mode():
"is_parallel": True, "is_parallel": True,
}, },
"id": "iteration-1", "id": "iteration-1",
}, }
)
sequential_iteration_node = IterationNode( parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=parallel_node_config,
)
# Initialize node data
parallel_iteration_node.init_node_data(parallel_node_config["data"])
sequential_node_config = {
"data": { "data": {
"iterator_selector": ["pe", "list_output"], "iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"], "output_selector": ["tt", "output"],
@ -631,9 +641,19 @@ def test_iteration_run_in_parallel_mode():
"is_parallel": True, "is_parallel": True,
}, },
"id": "iteration-1", "id": "iteration-1",
}, }
sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config=sequential_node_config,
) )
# Initialize node data
sequential_iteration_node.init_node_data(sequential_node_config["data"])
def tt_generator(self): def tt_generator(self):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -818,12 +838,7 @@ def test_iteration_run_error_handle():
environment_variables=[], environment_variables=[],
) )
pool.add(["pe", "list_output"], ["1", "1"]) pool.add(["pe", "list_output"], ["1", "1"])
iteration_node = IterationNode( error_node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": { "data": {
"iterator_selector": ["pe", "list_output"], "iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"], "output_selector": ["tt", "output"],
@ -836,8 +851,18 @@ def test_iteration_run_error_handle():
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
}, },
"id": "iteration-1", "id": "iteration-1",
}, }
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config=error_node_config,
) )
# Initialize node data
iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node # execute continue on error node
result = iteration_node._run() result = iteration_node._run()
result_arr = [] result_arr = []

@ -119,17 +119,20 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode: ) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode( node_config = {
id="1",
config={
"id": "1", "id": "1",
"data": llm_node_data.model_dump(), "data": llm_node_data.model_dump(),
}, }
node = LLMNode(
id="1",
config=node_config,
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver, llm_file_saver=mock_file_saver,
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node return node
@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
variable_pool = llm_node.graph_runtime_state.variable_pool variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages( result = llm_node.handle_list_messages(
messages=messages, messages=messages,
context=context, context=context,
jinja2_variables=jinja2_variables, jinja2_variables=jinja2_variables,
@ -506,17 +509,20 @@ def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]: ) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode( node_config = {
id="1",
config={
"id": "1", "id": "1",
"data": llm_node_data.model_dump(), "data": llm_node_data.model_dump(),
}, }
node = LLMNode(
id="1",
config=node_config,
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver, llm_file_saver=mock_file_saver,
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node, mock_file_saver return node, mock_file_saver
@ -544,6 +550,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
content=content, content=content,
file_saver=mock_file_saver, file_saver=mock_file_saver,
) )
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with( mock_file_saver.save_binary_string.assert_called_once_with(
@ -573,6 +581,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
content=content, content=content,
file_saver=mock_file_saver, file_saver=mock_file_saver,
) )
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
@ -588,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
class TestSaveMultimodalOutputAndConvertResultToMarkdown: class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal): def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents="hello world", file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
@ -596,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_text_prompt_message_content(self, llm_node_for_multimodal): def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")] contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
) )
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
@ -622,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
) )
mock_file_saver.save_binary_string.return_value = mock_saved_file mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[ contents=[
ImagePromptMessageContent( ImagePromptMessageContent(
format="png", format="png",
base64_data=image_b64_data, base64_data=image_b64_data,
mime_type="image/png", mime_type="image/png",
) )
] ],
file_saver=mock_file_saver,
file_outputs=llm_node._file_outputs,
) )
yielded_strs = list(gen) yielded_strs = list(gen)
assert len(yielded_strs) == 1 assert len(yielded_strs) == 1
@ -651,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_unknown_content_type(self, llm_node_for_multimodal): def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal): def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal): def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=None, file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == [] assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()

@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny") variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.") variable_pool.add(["llm", "text"], "You are a helpful AI.")
node = AnswerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "answer", "id": "answer",
"data": { "data": {
"title": "123", "title": "123",
"type": "answer", "type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
}, },
}, }
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()

@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor", title="Test Document Extractor",
variable_selector=["node_id", "variable_name"], variable_selector=["node_id", "variable_name"],
) )
return DocumentExtractorNode( node_config = {"id": "test_node_id", "data": node_data.model_dump()}
node = DocumentExtractorNode(
id="test_node_id", id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()}, config=node_config,
graph_init_params=Mock(), graph_init_params=Mock(),
graph=Mock(), graph=Mock(),
graph_runtime_state=Mock(), graph_runtime_state=Mock(),
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node
@pytest.fixture @pytest.fixture

@ -57,12 +57,7 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None) pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212") pool.add(["start", "not_null"], "1212")
node = IfElseNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else", "id": "if-else",
"data": { "data": {
"title": "123", "title": "123",
@ -105,9 +100,19 @@ def test_execute_if_else_result_true():
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
], ],
}, },
}, }
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()
@ -162,12 +167,7 @@ def test_execute_if_else_result_false():
pool.add(["start", "array_contains"], ["1ab", "def"]) pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ab", "def"])
node = IfElseNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else", "id": "if-else",
"data": { "data": {
"title": "123", "title": "123",
@ -186,9 +186,19 @@ def test_execute_if_else_result_false():
}, },
], ],
}, },
}, }
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()
@ -228,17 +238,22 @@ def test_array_file_contains_file_name():
], ],
) )
node_config = {
"id": "if-else",
"data": node_data.model_dump(),
}
node = IfElseNode( node = IfElseNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=Mock(), graph_init_params=Mock(),
graph=Mock(), graph=Mock(),
graph_runtime_state=Mock(), graph_runtime_state=Mock(),
config={ config=node_config,
"id": "if-else",
"data": node_data.model_dump(),
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[ value=[
File( File(

@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title", "title": "Test Title",
} }
node_data = ListOperatorNodeData(**config) node_data = ListOperatorNodeData(**config)
node = ListOperatorNode( node_config = {
id="test_node_id",
config={
"id": "test_node_id", "id": "test_node_id",
"data": node_data.model_dump(), "data": node_data.model_dump(),
}, }
node = ListOperatorNode(
id="test_node_id",
config=node_config,
graph_init_params=MagicMock(), graph_init_params=MagicMock(),
graph=MagicMock(), graph=MagicMock(),
graph_runtime_state=MagicMock(), graph_runtime_state=MagicMock(),
) )
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock() node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock() node.graph_runtime_state.variable_pool = MagicMock()
return node return node

@ -38,12 +38,13 @@ def _create_tool_node():
system_variables=SystemVariable.empty(), system_variables=SystemVariable.empty(),
user_inputs={}, user_inputs={},
) )
node = ToolNode( node_config = {
id="1",
config={
"id": "1", "id": "1",
"data": data.model_dump(), "data": data.model_dump(),
}, }
node = ToolNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -71,6 +72,8 @@ def _create_tool_node():
start_at=0, start_at=0,
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node return node

@ -82,12 +82,7 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id", "id": "node_id",
"data": { "data": {
"title": "test", "title": "test",
@ -95,10 +90,20 @@ def test_overwrite_string_variable():
"write_mode": WriteMode.OVER_WRITE.value, "write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
}, },
}, }
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory, conv_var_updater_factory=mock_conv_var_updater_factory,
) )
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run()) list(node.run())
expected_var = StringVariable( expected_var = StringVariable(
id=conversation_variable.id, id=conversation_variable.id,
@ -178,12 +183,7 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id", "id": "node_id",
"data": { "data": {
"title": "test", "title": "test",
@ -191,10 +191,20 @@ def test_append_variable_to_array():
"write_mode": WriteMode.APPEND.value, "write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
}, },
}, }
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory, conv_var_updater_factory=mock_conv_var_updater_factory,
) )
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run()) list(node.run())
expected_value = list(conversation_variable.value) expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value) expected_value.append(input_variable.value)
@ -265,12 +275,7 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id", "id": "node_id",
"data": { "data": {
"title": "test", "title": "test",
@ -278,10 +283,20 @@ def test_clear_array():
"write_mode": WriteMode.CLEAR.value, "write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [], "input_variable_selector": [],
}, },
}, }
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory, conv_var_updater_factory=mock_conv_var_updater_factory,
) )
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run()) list(node.run())
expected_var = ArrayStringVariable( expected_var = ArrayStringVariable(
id=conversation_variable.id, id=conversation_variable.id,

@ -115,12 +115,7 @@ def test_remove_first_from_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node = VariableAssignerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id", "id": "node_id",
"data": { "data": {
"title": "test", "title": "test",
@ -134,9 +129,19 @@ def test_remove_first_from_array():
} }
], ],
}, },
}, }
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
# Print the variable before running # Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@ -202,12 +207,7 @@ def test_remove_last_from_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node = VariableAssignerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id", "id": "node_id",
"data": { "data": {
"title": "test", "title": "test",
@ -221,9 +221,19 @@ def test_remove_last_from_array():
} }
], ],
}, },
}, }
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
list(node.run()) list(node.run())
@ -281,12 +291,7 @@ def test_remove_first_from_empty_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node = VariableAssignerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id", "id": "node_id",
"data": { "data": {
"title": "test", "title": "test",
@ -300,9 +305,19 @@ def test_remove_first_from_empty_array():
} }
], ],
}, },
}, }
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
list(node.run()) list(node.run())
@ -360,12 +375,7 @@ def test_remove_last_from_empty_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node = VariableAssignerNode( node_config = {
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id", "id": "node_id",
"data": { "data": {
"title": "test", "title": "test",
@ -379,9 +389,19 @@ def test_remove_last_from_empty_array():
} }
], ],
}, },
}, }
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=node_config,
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
list(node.run()) list(node.run())

Loading…
Cancel
Save