refactor: Refactors provider retrieval using SQLAlchemy session

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20586/head
-LAN- 12 months ago
parent 582aef1b98
commit d2d7f00144
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@ -393,19 +395,13 @@ class ProviderManager:
@staticmethod @staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
"""
Get all provider records of the workspace.
:param tenant_id: workspace id
:return:
"""
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
provider_name_to_provider_records_dict = defaultdict(list) provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers: with Session(db.engine, expire_on_commit=False) as session:
# TODO: Use provider name with prefix after the data migration stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) providers = session.scalars(stmt)
for provider in providers:
# Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict
@staticmethod @staticmethod

@ -103,9 +103,6 @@ def test_execute_llm(setup_model_mock):
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Mock db.session.close()
db.session.close = MagicMock()
node._fetch_model_config = get_mocked_fetch_model_config( node._fetch_model_config = get_mocked_fetch_model_config(
provider="langgenius/openai/openai", provider="langgenius/openai/openai",
model="gpt-3.5-turbo", model="gpt-3.5-turbo",

Loading…
Cancel
Save