From d2d7f0014442b07b52f24dc35dba36a57fb9bb6c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 3 Jun 2025 17:01:25 +0800 Subject: [PATCH] refactor: Refactors provider retrieval using SQLAlchemy session Signed-off-by: -LAN- --- api/core/provider_manager.py | 20 ++++++++----------- .../workflow/nodes/test_llm.py | 3 --- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 7570200175..76452f63e0 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,9 @@ from collections import defaultdict from json import JSONDecodeError from typing import Any, Optional, cast +from sqlalchemy import select from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity @@ -393,19 +395,13 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: - """ - Get all provider records of the workspace. - - :param tenant_id: workspace id - :return: - """ - providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() - provider_name_to_provider_records_dict = defaultdict(list) - for provider in providers: - # TODO: Use provider name with prefix after the data migration - provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) + providers = session.scalars(stmt) + for provider in providers: + # Use provider name with prefix after the data migration + provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) return provider_name_to_provider_records_dict @staticmethod diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 5fbee266bd..d4c4fc5006 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -103,9 +103,6 @@ def test_execute_llm(setup_model_mock): credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Mock db.session.close() - db.session.close = MagicMock() - node._fetch_model_config = get_mocked_fetch_model_config( provider="langgenius/openai/openai", model="gpt-3.5-turbo",