diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 4b8214019c..b3affc91a6 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -7,7 +7,6 @@ from json import JSONDecodeError from typing import Optional from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import or_ from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity @@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel): else [], ) - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: + def _get_custom_provider_credentials(self) -> Provider | None: """ - Validate custom credentials. - :param credentials: provider credentials - :return: + Get custom provider credentials. """ # get provider model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] if model_provider_id.is_langgenius(): - provider_record = ( - db.session.query(Provider) - .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - or_( - Provider.provider_name == model_provider_id.provider_name, - Provider.provider_name == self.provider.provider, - ), - ) - .first() - ) - else: - provider_record = ( - db.session.query(Provider) - .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name == self.provider.provider, - ) - .first() + provider_names.append(model_provider_id.provider_name) + + provider_record = ( + db.session.query(Provider) + .filter( + Provider.tenant_id == self.tenant_id, + Provider.provider_type == ProviderType.CUSTOM.value, + Provider.provider_name.in_(provider_names), ) + .first() + ) + + return provider_record + + def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: + """ + Validate custom credentials. + :param credentials: provider credentials + :return: + """ + provider_record = self._get_custom_provider_credentials() # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( @@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = ( - db.session.query(Provider) - .filter( - Provider.tenant_id == self.tenant_id, - or_( - Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name, - Provider.provider_name == self.provider.provider, - ), - Provider.provider_type == ProviderType.CUSTOM.value, - ) - .first() - ) + provider_record = self._get_custom_provider_credentials() # delete provider if provider_record: @@ -349,29 +335,47 @@ class ProviderConfiguration(BaseModel): return None - def custom_model_credentials_validate( - self, model_type: ModelType, model: str, credentials: dict - ) -> tuple[ProviderModel | None, dict]: + def _get_custom_model_credentials( + self, + model_type: ModelType, + model: str, + ) -> ProviderModel | None: """ - Validate custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials - :return: + Get custom model credentials. """ # get provider model + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + provider_model_record = ( db.session.query(ProviderModel) .filter( ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, + ProviderModel.provider_name.in_(provider_names), ProviderModel.model_name == model, ProviderModel.model_type == model_type.to_origin_model_type(), ) .first() ) + return provider_model_record + + def custom_model_credentials_validate( + self, model_type: ModelType, model: str, credentials: dict + ) -> tuple[ProviderModel | None, dict]: + """ + Validate custom model credentials. + + :param model_type: model type + :param model: model name + :param credentials: model credentials + :return: + """ + # get provider model + provider_model_record = self._get_custom_model_credentials(model_type, model) + # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas @@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = ( - db.session.query(ProviderModel) - .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type(), - ) - .first() - ) + provider_model_record = self._get_custom_model_credentials(model_type, model) # delete provider model if provider_model_record: @@ -475,24 +470,35 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache.delete() - def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: """ - Enable model. - :param model_type: model type - :param model: model name - :return: + Get provider model setting. """ - model_setting = ( + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + + return ( db.session.query(ProviderModelSetting) .filter( ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_name == model, ) .first() ) + def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Enable model. + :param model_type: model type + :param model: model name + :return: + """ + model_setting = self._get_provider_model_setting(model_type, model) + if model_setting: model_setting.enabled = True model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) @@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = ( - db.session.query(ProviderModelSetting) - .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + model_setting = self._get_provider_model_setting(model_type, model) if model_setting: model_setting.enabled = False @@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + return self._get_provider_model_setting(model_type, model) + + def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]: + """ + Get load balancing config. + """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + return ( - db.session.query(ProviderModelSetting) + db.session.query(LoadBalancingModelConfig) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name.in_(provider_names), + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, ) .first() ) @@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + load_balancing_config_count = ( db.session.query(LoadBalancingModelConfig) .filter( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) @@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel): if load_balancing_config_count <= 1: raise ValueError("Model load balancing configuration must be more than 1.") - model_setting = ( - db.session.query(ProviderModelSetting) - .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + model_setting = self._get_provider_model_setting(model_type, model) if model_setting: model_setting.load_balancing_enabled = True @@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + model_setting = ( db.session.query(ProviderModelSetting) .filter( ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_name == model, ) @@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel): return # get preferred provider + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + preferred_model_provider = ( db.session.query(TenantPreferredModelProvider) .filter( TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider, + TenantPreferredModelProvider.provider_name.in_(provider_names), ) .first() )