diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index e89e03ae86..0df8e8b146 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -353,7 +353,7 @@ def test_extract_json_from_tool_call(): assert result["location"] == "kawaii" -def test_chat_parameter_extractor_with_memory(setup_model_mock): +def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch): """ Test chat parameter extractor with memory. """ @@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock): mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) - node._fetch_memory = get_mocked_fetch_memory("customized memory") + # Test the mock before running the actual test + monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory")) db.session.close = MagicMock() result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 519dd73787..336c2befcc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.llm import llm_utils from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -170,7 +171,7 @@ def model_config(): ) -def test_fetch_files_with_file_segment(llm_node): +def test_fetch_files_with_file_segment(): file = File( id="1", tenant_id="test", @@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node): related_id="1", storage_key="", ) - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], file) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [file] -def test_fetch_files_with_array_file_segment(llm_node): +def test_fetch_files_with_array_file_segment(): files = [ File( id="1", @@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node): storage_key="", ), ] - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == files -def test_fetch_files_with_none_segment(llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) +def test_fetch_files_with_none_segment(): + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], NoneSegment()) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [] -def test_fetch_files_with_array_any_segment(llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) +def test_fetch_files_with_array_any_segment(): + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [] -def test_fetch_files_with_non_existent_variable(llm_node): - result = llm_node._fetch_files(selector=["sys", "files"]) +def test_fetch_files_with_non_existent_variable(): + variable_pool = VariablePool() + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == []