test(test_llm): Fix test

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 67269ab61e
commit 678a39a113
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -2,15 +2,10 @@ import json
import time import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
@ -24,9 +19,6 @@ from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE""" """FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
def init_llm_node(config: dict) -> LLMNode: def init_llm_node(config: dict) -> LLMNode:
@ -92,7 +84,7 @@ def init_llm_node(config: dict) -> LLMNode:
return node return node
def test_execute_llm(flask_req_ctx, setup_model_mock): def test_execute_llm():
node = init_llm_node( node = init_llm_node(
config={ config={
"id": "llm", "id": "llm",
@ -100,7 +92,7 @@ def test_execute_llm(flask_req_ctx, setup_model_mock):
"title": "123", "title": "123",
"type": "llm", "type": "llm",
"model": { "model": {
"provider": "langgenius/openai/openai", "provider": "openai",
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"mode": "chat", "mode": "chat",
"completion_params": {}, "completion_params": {},
@ -121,6 +113,61 @@ def test_execute_llm(flask_req_ctx, setup_model_mock):
db.session.close = MagicMock() db.session.close = MagicMock()
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response from mock")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
# execute node # execute node
result = node._run() result = node._run()
assert isinstance(result, Generator) assert isinstance(result, Generator)
@ -137,8 +184,7 @@ def test_execute_llm(flask_req_ctx, setup_model_mock):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_llm_with_jinja2():
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
""" """
Test execute LLM node with jinja2 Test execute LLM node with jinja2
""" """
@ -179,7 +225,16 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()
# Create a proper LLM result with real entities # Mock the _fetch_model_config method
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage( mock_usage = LLMUsage(
prompt_tokens=30, prompt_tokens=30,
prompt_unit_price=Decimal("0.001"), prompt_unit_price=Decimal("0.001"),
@ -194,38 +249,36 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
currency="USD", currency="USD",
latency=0.5, latency=0.5,
) )
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult( mock_llm_result = LLMResult(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
prompt_messages=[], prompt_messages=[],
message=mock_message, message=mock_message,
usage=mock_usage, usage=mock_usage,
) )
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes # Create mock model config
mock_model_config = MagicMock() mock_model_config = MagicMock()
mock_model_config.mode = "chat" mock_model_config.mode = "chat"
mock_model_config.provider = "openai" mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo" mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" mock_model_config.parameters = {}
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls # Mock fetch_prompt_messages to avoid database calls
def mock_get_model_instance(_self, **kwargs): def mock_fetch_prompt_messages_2(**_kwargs):
return mock_model_instance from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with ( with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
): ):
# execute node # execute node
result = node._run() result = node._run()

Loading…
Cancel
Save