refactor: Refactors provider retrieval using SQLAlchemy session

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20586/head
-LAN- 12 months ago
parent 582aef1b98
commit d2d7f00144
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@ -393,19 +395,13 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
"""
Get all provider records of the workspace.
:param tenant_id: workspace id
:return:
"""
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers:
# TODO: Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
# Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict
@staticmethod

@ -103,9 +103,6 @@ def test_execute_llm(setup_model_mock):
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Mock db.session.close()
db.session.close = MagicMock()
node._fetch_model_config = get_mocked_fetch_model_config(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",

Loading…
Cancel
Save