diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 7570200175..76452f63e0 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,9 @@ from collections import defaultdict from json import JSONDecodeError from typing import Any, Optional, cast +from sqlalchemy import select from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity @@ -393,19 +395,13 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: - """ - Get all provider records of the workspace. - - :param tenant_id: workspace id - :return: - """ - providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() - provider_name_to_provider_records_dict = defaultdict(list) - for provider in providers: - # TODO: Use provider name with prefix after the data migration - provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) + providers = session.scalars(stmt) + for provider in providers: + # Use provider name with prefix after the data migration + provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) return provider_name_to_provider_records_dict @staticmethod diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 5fbee266bd..d4c4fc5006 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -103,9 +103,6 @@ def test_execute_llm(setup_model_mock): credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Mock db.session.close() - db.session.close = MagicMock() - node._fetch_model_config = get_mocked_fetch_model_config( provider="langgenius/openai/openai", model="gpt-3.5-turbo",