|
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
|
|
|
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
|
|
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
|
|
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
|
|
|
from core.entities.provider_entities import (
|
|
|
|
from core.entities.provider_entities import (
|
|
|
|
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from core.helper import encrypter
|
|
|
|
from core.helper import encrypter
|
|
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
|
|
|
|
|
|
from core.helper.position_helper import is_filtered
|
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
|
from core.model_runtime.entities.provider_entities import (
|
|
|
|
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
|
|
|
|
CredentialFormSchema,
|
|
|
|
|
|
|
|
FormType,
|
|
|
|
|
|
|
|
ProviderEntity,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
from core.model_runtime.model_providers import model_provider_factory
|
|
|
|
from core.model_runtime.model_providers import model_provider_factory
|
|
|
|
from extensions import ext_hosting_provider
|
|
|
|
from extensions import ext_hosting_provider
|
|
|
|
from extensions.ext_database import db
|
|
|
|
from extensions.ext_database import db
|
|
|
|
@ -45,6 +43,7 @@ class ProviderManager:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
|
|
|
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
def __init__(self) -> None:
|
|
|
|
self.decoding_rsa_key = None
|
|
|
|
self.decoding_rsa_key = None
|
|
|
|
self.decoding_cipher_rsa = None
|
|
|
|
self.decoding_cipher_rsa = None
|
|
|
|
@ -117,6 +116,16 @@ class ProviderManager:
|
|
|
|
|
|
|
|
|
|
|
|
# Construct ProviderConfiguration objects for each provider
|
|
|
|
# Construct ProviderConfiguration objects for each provider
|
|
|
|
for provider_entity in provider_entities:
|
|
|
|
for provider_entity in provider_entities:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# handle include, exclude
|
|
|
|
|
|
|
|
if is_filtered(
|
|
|
|
|
|
|
|
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
|
|
|
|
|
|
|
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
|
|
|
|
|
|
|
data=provider_entity,
|
|
|
|
|
|
|
|
name_func=lambda x: x.provider,
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
provider_name = provider_entity.provider
|
|
|
|
provider_name = provider_entity.provider
|
|
|
|
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
|
|
|
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
|
|
|
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
|
|
|
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
|
|
|
@ -271,6 +280,24 @@ class ProviderManager:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Get names of first model and its provider
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param tenant_id: workspace id
|
|
|
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
|
|
|
:return: provider name, model name
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
provider_configurations = self.get_configurations(tenant_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# get available models from provider_configurations
|
|
|
|
|
|
|
|
all_models = provider_configurations.get_models(
|
|
|
|
|
|
|
|
model_type=model_type,
|
|
|
|
|
|
|
|
only_active=False
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return all_models[0].provider.provider, all_models[0].model
|
|
|
|
|
|
|
|
|
|
|
|
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
|
|
|
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
|
|
|
-> TenantDefaultModel:
|
|
|
|
-> TenantDefaultModel:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|