test(api): fix correspond unit tests

pull/20843/head
QuantumGhost 12 months ago
parent 1aa072af36
commit d070a06dd8

@ -353,7 +353,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii" assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock): def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
""" """
Test chat parameter extractor with memory. Test chat parameter extractor with memory.
""" """
@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
mode="chat", mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
) )
node._fetch_memory = get_mocked_fetch_memory("customized memory") # Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock() db.session.close = MagicMock()
result = node._run() result = node._run()

@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ( from core.workflow.nodes.llm.entities import (
ContextConfig, ContextConfig,
LLMNodeChatModelMessage, LLMNodeChatModelMessage,
@ -170,7 +171,7 @@ def model_config():
) )
def test_fetch_files_with_file_segment(llm_node): def test_fetch_files_with_file_segment():
file = File( file = File(
id="1", id="1",
tenant_id="test", tenant_id="test",
@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node):
related_id="1", related_id="1",
storage_key="", storage_key="",
) )
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) variable_pool = VariablePool()
variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"]) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == [file] assert result == [file]
def test_fetch_files_with_array_file_segment(llm_node): def test_fetch_files_with_array_file_segment():
files = [ files = [
File( File(
id="1", id="1",
@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node):
storage_key="", storage_key="",
), ),
] ]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"]) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == files assert result == files
def test_fetch_files_with_none_segment(llm_node): def test_fetch_files_with_none_segment():
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) variable_pool = VariablePool()
variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"]) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == [] assert result == []
def test_fetch_files_with_array_any_segment(llm_node): def test_fetch_files_with_array_any_segment():
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"]) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == [] assert result == []
def test_fetch_files_with_non_existent_variable(llm_node): def test_fetch_files_with_non_existent_variable():
result = llm_node._fetch_files(selector=["sys", "files"]) variable_pool = VariablePool()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == [] assert result == []

Loading…
Cancel
Save