refactor: Refactors database sessions and typing annotations

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20586/head
-LAN- 12 months ago
parent d2d7f00144
commit 29abcb465e
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -412,17 +412,12 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
# Get all provider model records of the workspace
provider_models = (
db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
.all()
)
provider_name_to_provider_model_records_dict = defaultdict(list) provider_name_to_provider_model_records_dict = defaultdict(list)
for provider_model in provider_models: with Session(db.engine, expire_on_commit=False) as session:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
return provider_name_to_provider_model_records_dict return provider_name_to_provider_model_records_dict
@staticmethod @staticmethod
@ -433,17 +428,14 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
preferred_provider_types = ( provider_name_to_preferred_provider_type_records_dict = {}
db.session.query(TenantPreferredModelProvider) with Session(db.engine, expire_on_commit=False) as session:
.filter(TenantPreferredModelProvider.tenant_id == tenant_id) stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
.all() preferred_provider_types = session.scalars(stmt)
) provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
provider_name_to_preferred_provider_type_records_dict = { for preferred_provider_type in preferred_provider_types
preferred_provider_type.provider_name: preferred_provider_type }
for preferred_provider_type in preferred_provider_types
}
return provider_name_to_preferred_provider_type_records_dict return provider_name_to_preferred_provider_type_records_dict
@staticmethod @staticmethod
@ -454,18 +446,14 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
provider_model_settings = (
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
)
provider_name_to_provider_model_settings_dict = defaultdict(list) provider_name_to_provider_model_settings_dict = defaultdict(list)
for provider_model_setting in provider_model_settings: with Session(db.engine, expire_on_commit=False) as session:
( stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
provider_model_setting provider_model_setting
) )
)
return provider_name_to_provider_model_settings_dict return provider_name_to_provider_model_settings_dict
@staticmethod @staticmethod
@ -488,15 +476,14 @@ class ProviderManager:
if not model_load_balancing_enabled: if not model_load_balancing_enabled:
return {} return {}
provider_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
)
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
for provider_load_balancing_config in provider_load_balancing_configs: with Session(db.engine, expire_on_commit=False) as session:
provider_name_to_provider_load_balancing_model_configs_dict[ stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_config.provider_name provider_load_balancing_configs = session.scalars(stmt)
].append(provider_load_balancing_config) for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
return provider_name_to_provider_load_balancing_model_configs_dict return provider_name_to_provider_load_balancing_model_configs_dict
@ -622,10 +609,9 @@ class ProviderManager:
if not cached_provider_credentials: if not cached_provider_credentials:
try: try:
# fix origin data # fix origin data
if ( if custom_provider_record.encrypted_config is None:
custom_provider_record.encrypted_config raise ValueError("No credentials found")
and not custom_provider_record.encrypted_config.startswith("{") if not custom_provider_record.encrypted_config.startswith("{"):
):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else: else:
provider_credentials = json.loads(custom_provider_record.encrypted_config) provider_credentials = json.loads(custom_provider_record.encrypted_config)
@ -729,7 +715,7 @@ class ProviderManager:
return SystemConfiguration(enabled=False) return SystemConfiguration(enabled=False)
# Convert provider_records to dict # Convert provider_records to dict
quota_type_to_provider_records_dict = {} quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records: for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value: if provider_record.provider_type != ProviderType.SYSTEM.value:
continue continue
@ -754,6 +740,11 @@ class ProviderManager:
else: else:
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
if provider_record.quota_used is None:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
quota_configuration = QuotaConfiguration( quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type, quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@ -787,10 +778,9 @@ class ProviderManager:
cached_provider_credentials = provider_credentials_cache.get() cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials: if not cached_provider_credentials:
try: provider_credentials: dict[str, Any] = {}
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) if provider_records and provider_records[0].encrypted_config:
except JSONDecodeError: provider_credentials = json.loads(provider_records[0].encrypted_config)
provider_credentials = {}
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(

@ -1,6 +1,9 @@
from datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional
from sqlalchemy import func from sqlalchemy import func, text
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base from .base import Base
from .engine import db from .engine import db
@ -51,20 +54,24 @@ class Provider(Base):
), ),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) provider_type: Mapped[str] = mapped_column(
encrypted_config = db.Column(db.Text, nullable=True) db.String(40), nullable=False, server_default=text("'custom'::character varying")
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) )
last_used = db.Column(db.DateTime, nullable=True) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) quota_type: Mapped[Optional[str]] = mapped_column(
quota_limit = db.Column(db.BigInteger, nullable=True) db.String(40), nullable=True, server_default=text("''::character varying")
quota_used = db.Column(db.BigInteger, default=0) )
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self): def __repr__(self):
return ( return (
@ -104,15 +111,15 @@ class ProviderModel(Base):
), ),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(Base): class TenantDefaultModel(Base):
@ -122,13 +129,13 @@ class TenantDefaultModel(Base):
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base): class TenantPreferredModelProvider(Base):
@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base):
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False) preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base): class ProviderOrder(Base):
@ -153,22 +160,24 @@ class ProviderOrder(Base):
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
account_id = db.Column(StringUUID, nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False) payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191)) payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
transaction_id = db.Column(db.String(191)) transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
currency = db.Column(db.String(40)) currency: Mapped[Optional[str]] = mapped_column(db.String(40))
total_amount = db.Column(db.Integer) total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) payment_status: Mapped[str] = mapped_column(
paid_at = db.Column(db.DateTime) db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
pay_failed_at = db.Column(db.DateTime) )
refunded_at = db.Column(db.DateTime) paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base): class ProviderModelSetting(Base):
@ -182,15 +191,15 @@ class ProviderModelSetting(Base):
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base): class LoadBalancingModelConfig(Base):
@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base):
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
name = db.Column(db.String(255), nullable=False) name: Mapped[str] = mapped_column(db.String(255), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@ -3,11 +3,16 @@ import os
import time import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from unittest.mock import MagicMock from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest import pytest
from app_factory import create_app
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
@ -19,13 +24,19 @@ from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
"""FOR MOCK FIXTURES, DO NOT REMOVE""" """FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.fixture(scope="session")
def app():
app = create_app()
dify_config.LOGIN_DISABLED = True
return app
def init_llm_node(config: dict) -> LLMNode: def init_llm_node(config: dict) -> LLMNode:
graph_config = { graph_config = {
"edges": [ "edges": [
@ -40,13 +51,19 @@ def init_llm_node(config: dict) -> LLMNode:
graph = Graph.init(graph_config=graph_config) graph = Graph.init(graph_config=graph_config)
# Use proper UUIDs for database compatibility
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
init_params = GraphInitParams( init_params = GraphInitParams(
tenant_id="1", tenant_id=tenant_id,
app_id="1", app_id=app_id,
workflow_type=WorkflowType.WORKFLOW, workflow_type=WorkflowType.WORKFLOW,
workflow_id="1", workflow_id=workflow_id,
graph_config=graph_config, graph_config=graph_config,
user_id="1", user_id=user_id,
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
call_depth=0, call_depth=0,
@ -77,112 +94,197 @@ def init_llm_node(config: dict) -> LLMNode:
return node return node
def test_execute_llm(setup_model_mock): def test_execute_llm(app):
node = init_llm_node( with app.app_context():
config={ node = init_llm_node(
"id": "llm", config={
"data": { "id": "llm",
"title": "123", "data": {
"type": "llm", "title": "123",
"model": { "type": "llm",
"provider": "langgenius/openai/openai", "model": {
"name": "gpt-3.5-turbo", "provider": "langgenius/openai/openai",
"mode": "chat", "name": "gpt-3.5-turbo",
"completion_params": {}, "mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
}, },
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
}, },
}, )
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
node._fetch_model_config = get_mocked_fetch_model_config( # Create a proper LLM result with real entities
provider="langgenius/openai/openai", mock_usage = LLMUsage(
model="gpt-3.5-turbo", prompt_tokens=30,
mode="chat", prompt_unit_price=Decimal("0.001"),
credentials=credentials, prompt_price_unit=Decimal("1000"),
) prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
# execute node mock_llm_result = LLMResult(
result = node._run() model="gpt-3.5-turbo",
assert isinstance(result, Generator) prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
for item in result: # Create a simple mock model instance that doesn't call real providers
if isinstance(item, RunCompletedEvent): mock_model_instance = MagicMock()
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED mock_model_instance.invoke_llm.return_value = mock_llm_result
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None # Create a simple mock model config with required attributes
assert item.run_result.outputs.get("text") is not None mock_model_config = MagicMock()
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
""" """
Test execute LLM node with jinja2 Test execute LLM node with jinja2
""" """
node = init_llm_node( with app.app_context():
config={ node = init_llm_node(
"id": "llm", config={
"data": { "id": "llm",
"title": "123", "data": {
"type": "llm", "title": "123",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, "type": "llm",
"prompt_config": { "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"jinja2_variables": [ "prompt_config": {
{"variable": "sys_query", "value_selector": ["sys", "query"]}, "jinja2_variables": [
{"variable": "output", "value_selector": ["abc", "output"]}, {"variable": "sys_query", "value_selector": ["sys", "query"]},
] {"variable": "output", "value_selector": ["abc", "output"]},
}, ]
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
}, },
], "prompt_template": [
"memory": None, {
"context": {"enabled": False}, "role": "system",
"vision": {"enabled": False}, "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
}, },
}, )
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} # Mock db.session.close()
db.session.close = MagicMock()
# Mock db.session.close() # Create a proper LLM result with real entities
db.session.close = MagicMock() mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
node._fetch_model_config = get_mocked_fetch_model_config( mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
provider="langgenius/openai/openai",
model="gpt-3.5-turbo", mock_llm_result = LLMResult(
mode="chat", model="gpt-3.5-turbo",
credentials=credentials, prompt_messages=[],
) message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# execute node with (
result = node._run() patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
for item in result: for item in result:
if isinstance(item, RunCompletedEvent): if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data) assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data) assert "what's the weather today?" in json.dumps(item.run_result.process_data)
def test_extract_json(): def test_extract_json():

Loading…
Cancel
Save