|
|
|
|
@ -3,7 +3,9 @@ from collections import defaultdict
|
|
|
|
|
from json import JSONDecodeError
|
|
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
|
|
|
|
@ -393,19 +395,13 @@ class ProviderManager:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
|
|
|
|
|
"""
|
|
|
|
|
Get all provider records of the workspace.
|
|
|
|
|
|
|
|
|
|
:param tenant_id: workspace id
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
|
|
|
|
|
|
|
|
|
|
provider_name_to_provider_records_dict = defaultdict(list)
|
|
|
|
|
for provider in providers:
|
|
|
|
|
# TODO: Use provider name with prefix after the data migration
|
|
|
|
|
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
|
|
|
|
|
|
|
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
|
|
|
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
|
|
|
|
|
providers = session.scalars(stmt)
|
|
|
|
|
for provider in providers:
|
|
|
|
|
# Use provider name with prefix after the data migration
|
|
|
|
|
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
|
|
|
|
return provider_name_to_provider_records_dict
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|