diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 76452f63e0..488a394679 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -412,17 +412,12 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - # Get all provider model records of the workspace - provider_models = ( - db.session.query(ProviderModel) - .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) - .all() - ) - provider_name_to_provider_model_records_dict = defaultdict(list) - for provider_model in provider_models: - provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + provider_models = session.scalars(stmt) + for provider_model in provider_models: + provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) return provider_name_to_provider_model_records_dict @staticmethod @@ -433,17 +428,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = ( - db.session.query(TenantPreferredModelProvider) - .filter(TenantPreferredModelProvider.tenant_id == tenant_id) - .all() - ) - - provider_name_to_preferred_provider_type_records_dict = { - preferred_provider_type.provider_name: preferred_provider_type - for preferred_provider_type in preferred_provider_types - } - + provider_name_to_preferred_provider_type_records_dict = {} + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) + preferred_provider_types = session.scalars(stmt) + provider_name_to_preferred_provider_type_records_dict = { + preferred_provider_type.provider_name: preferred_provider_type + for preferred_provider_type in preferred_provider_types + } return provider_name_to_preferred_provider_type_records_dict @staticmethod @@ -454,18 +446,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = ( - db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() - ) - provider_name_to_provider_model_settings_dict = defaultdict(list) - for provider_model_setting in provider_model_settings: - ( + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) + provider_model_settings = session.scalars(stmt) + for provider_model_setting in provider_model_settings: provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_model_setting ) - ) - return provider_name_to_provider_model_settings_dict @staticmethod @@ -488,15 +476,14 @@ class ProviderManager: if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() - ) - provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - for provider_load_balancing_config in provider_load_balancing_configs: - provider_name_to_provider_load_balancing_model_configs_dict[ - provider_load_balancing_config.provider_name - ].append(provider_load_balancing_config) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) + provider_load_balancing_configs = session.scalars(stmt) + for provider_load_balancing_config in provider_load_balancing_configs: + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) return provider_name_to_provider_load_balancing_model_configs_dict @@ -622,10 +609,9 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if ( - custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{") - ): + if custom_provider_record.encrypted_config is None: + raise ValueError("No credentials found") + if not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -729,7 +715,7 @@ class ProviderManager: return SystemConfiguration(enabled=False) # Convert provider_records to dict - quota_type_to_provider_records_dict = {} + quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue @@ -754,6 +740,11 @@ class ProviderManager: else: provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] + if provider_record.quota_used is None: + raise ValueError("quota_used is None") + if provider_record.quota_limit is None: + raise ValueError("quota_limit is None") + quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, @@ -787,10 +778,9 @@ class ProviderManager: cached_provider_credentials = provider_credentials_cache.get() if not cached_provider_credentials: - try: - provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + provider_credentials: dict[str, Any] = {} + if provider_records and provider_records[0].encrypted_config: + provider_credentials = json.loads(provider_records[0].encrypted_config) # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( diff --git a/api/models/provider.py b/api/models/provider.py index 497cbefc61..1e25f0c90f 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,6 +1,9 @@ +from datetime import datetime from enum import Enum +from typing import Optional -from sqlalchemy import func +from sqlalchemy import func, text +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -51,20 +54,24 @@ class Provider(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used = db.Column(db.DateTime, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_type: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'custom'::character varying") + ) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) - quota_limit = db.Column(db.BigInteger, nullable=True) - quota_used = db.Column(db.BigInteger, default=0) + quota_type: Mapped[Optional[str]] = mapped_column( + db.String(40), nullable=True, server_default=text("''::character varying") + ) + quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -104,15 +111,15 @@ class ProviderModel(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -122,13 +129,13 @@ class TenantDefaultModel(Base): db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base): db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -153,22 +160,24 @@ class ProviderOrder(Base): db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - account_id = db.Column(StringUUID, nullable=False) - payment_product_id = db.Column(db.String(191), nullable=False) - payment_id = db.Column(db.String(191)) - transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) - currency = db.Column(db.String(40)) - total_amount = db.Column(db.Integer) - payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) - paid_at = db.Column(db.DateTime) - pay_failed_at = db.Column(db.DateTime) - refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) + currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + payment_status: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + ) + paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -182,15 +191,15 @@ class ProviderModelSetting(Base): db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base): db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - name = db.Column(db.String(255), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d4c4fc5006..981ffd80d6 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -3,11 +3,16 @@ import os import time import uuid from collections.abc import Generator -from unittest.mock import MagicMock +from decimal import Decimal +from unittest.mock import MagicMock, patch import pytest +from app_factory import create_app +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey @@ -19,13 +24,19 @@ from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType -from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +@pytest.fixture(scope="session") +def app(): + app = create_app() + dify_config.LOGIN_DISABLED = True + return app + + def init_llm_node(config: dict) -> LLMNode: graph_config = { "edges": [ @@ -40,13 +51,19 @@ def init_llm_node(config: dict) -> LLMNode: graph = Graph.init(graph_config=graph_config) + # Use proper UUIDs for database compatibility + tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c" + workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" + user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" + init_params = GraphInitParams( - tenant_id="1", - app_id="1", + tenant_id=tenant_id, + app_id=app_id, workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", + workflow_id=workflow_id, graph_config=graph_config, - user_id="1", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, call_depth=0, @@ -77,112 +94,197 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(setup_model_mock): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": { - "provider": "langgenius/openai/openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": {}, +def test_execute_llm(app): + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": {}, + }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", + }, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - "prompt_template": [ - {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) + + mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") - # execute node - result = node._run() - assert isinstance(result, Generator) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert item.run_result.outputs is not None - assert item.run_result.outputs.get("text") is not None - assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "langgenius/openai/openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance + + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + assert isinstance(result, Generator) + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): +def test_execute_llm_with_jinja2(app, setup_code_executor_mock): """ Test execute LLM node with jinja2 """ - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, - "prompt_config": { - "jinja2_variables": [ - {"variable": "sys_query", "value_selector": ["sys", "query"]}, - {"variable": "output", "value_selector": ["abc", "output"]}, - ] - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", - "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", - "edition_type": "jinja2", - }, - { - "role": "user", - "text": "{{#sys.query#}}", - "jinja2_text": "{{sys_query}}", - "edition_type": "basic", + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + # Mock db.session.close() + db.session.close = MagicMock() - # Mock db.session.close() - db.session.close = MagicMock() + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") + + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) + + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance - # execute node - result = node._run() + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert "sunny" in json.dumps(item.run_result.process_data) - assert "what's the weather today?" in json.dumps(item.run_result.process_data) + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) def test_extract_json():