|
|
|
|
@ -7,7 +7,6 @@ from json import JSONDecodeError
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
|
from sqlalchemy import or_
|
|
|
|
|
|
|
|
|
|
from constants import HIDDEN_VALUE
|
|
|
|
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
|
|
|
|
@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
else [],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
|
|
|
|
def _get_custom_provider_credentials(self) -> Provider | None:
|
|
|
|
|
"""
|
|
|
|
|
Validate custom credentials.
|
|
|
|
|
:param credentials: provider credentials
|
|
|
|
|
:return:
|
|
|
|
|
Get custom provider credentials.
|
|
|
|
|
"""
|
|
|
|
|
# get provider
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_record = (
|
|
|
|
|
db.session.query(Provider)
|
|
|
|
|
.filter(
|
|
|
|
|
Provider.tenant_id == self.tenant_id,
|
|
|
|
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
|
|
or_(
|
|
|
|
|
Provider.provider_name == model_provider_id.provider_name,
|
|
|
|
|
Provider.provider_name == self.provider.provider,
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
provider_record = (
|
|
|
|
|
db.session.query(Provider)
|
|
|
|
|
.filter(
|
|
|
|
|
Provider.tenant_id == self.tenant_id,
|
|
|
|
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
|
|
Provider.provider_name == self.provider.provider,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
provider_record = (
|
|
|
|
|
db.session.query(Provider)
|
|
|
|
|
.filter(
|
|
|
|
|
Provider.tenant_id == self.tenant_id,
|
|
|
|
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
|
|
Provider.provider_name.in_(provider_names),
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return provider_record
|
|
|
|
|
|
|
|
|
|
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
|
|
|
|
"""
|
|
|
|
|
Validate custom credentials.
|
|
|
|
|
:param credentials: provider credentials
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
provider_record = self._get_custom_provider_credentials()
|
|
|
|
|
|
|
|
|
|
# Get provider credential secret variables
|
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
|
|
@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
# get provider
|
|
|
|
|
provider_record = (
|
|
|
|
|
db.session.query(Provider)
|
|
|
|
|
.filter(
|
|
|
|
|
Provider.tenant_id == self.tenant_id,
|
|
|
|
|
or_(
|
|
|
|
|
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
|
|
|
|
Provider.provider_name == self.provider.provider,
|
|
|
|
|
),
|
|
|
|
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
provider_record = self._get_custom_provider_credentials()
|
|
|
|
|
|
|
|
|
|
# delete provider
|
|
|
|
|
if provider_record:
|
|
|
|
|
@ -349,29 +335,47 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def custom_model_credentials_validate(
|
|
|
|
|
self, model_type: ModelType, model: str, credentials: dict
|
|
|
|
|
) -> tuple[ProviderModel | None, dict]:
|
|
|
|
|
def _get_custom_model_credentials(
|
|
|
|
|
self,
|
|
|
|
|
model_type: ModelType,
|
|
|
|
|
model: str,
|
|
|
|
|
) -> ProviderModel | None:
|
|
|
|
|
"""
|
|
|
|
|
Validate custom model credentials.
|
|
|
|
|
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
:param model: model name
|
|
|
|
|
:param credentials: model credentials
|
|
|
|
|
:return:
|
|
|
|
|
Get custom model credentials.
|
|
|
|
|
"""
|
|
|
|
|
# get provider model
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
provider_model_record = (
|
|
|
|
|
db.session.query(ProviderModel)
|
|
|
|
|
.filter(
|
|
|
|
|
ProviderModel.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModel.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModel.provider_name.in_(provider_names),
|
|
|
|
|
ProviderModel.model_name == model,
|
|
|
|
|
ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return provider_model_record
|
|
|
|
|
|
|
|
|
|
def custom_model_credentials_validate(
|
|
|
|
|
self, model_type: ModelType, model: str, credentials: dict
|
|
|
|
|
) -> tuple[ProviderModel | None, dict]:
|
|
|
|
|
"""
|
|
|
|
|
Validate custom model credentials.
|
|
|
|
|
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
:param model: model name
|
|
|
|
|
:param credentials: model credentials
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
# get provider model
|
|
|
|
|
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
|
|
|
|
|
|
|
|
|
# Get provider credential secret variables
|
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
|
|
self.provider.model_credential_schema.credential_form_schemas
|
|
|
|
|
@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
# get provider model
|
|
|
|
|
provider_model_record = (
|
|
|
|
|
db.session.query(ProviderModel)
|
|
|
|
|
.filter(
|
|
|
|
|
ProviderModel.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModel.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModel.model_name == model,
|
|
|
|
|
ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
|
|
|
|
|
|
|
|
|
# delete provider model
|
|
|
|
|
if provider_model_record:
|
|
|
|
|
@ -475,24 +470,35 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
|
|
|
|
provider_model_credentials_cache.delete()
|
|
|
|
|
|
|
|
|
|
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
|
|
def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
|
|
|
|
|
"""
|
|
|
|
|
Enable model.
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
:param model: model name
|
|
|
|
|
:return:
|
|
|
|
|
Get provider model setting.
|
|
|
|
|
"""
|
|
|
|
|
model_setting = (
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
db.session.query(ProviderModelSetting)
|
|
|
|
|
.filter(
|
|
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelSetting.provider_name.in_(provider_names),
|
|
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
|
|
"""
|
|
|
|
|
Enable model.
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
:param model: model name
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
model_setting = self._get_provider_model_setting(model_type, model)
|
|
|
|
|
|
|
|
|
|
if model_setting:
|
|
|
|
|
model_setting.enabled = True
|
|
|
|
|
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
|
|
|
@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:param model: model name
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
model_setting = (
|
|
|
|
|
db.session.query(ProviderModelSetting)
|
|
|
|
|
.filter(
|
|
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
model_setting = self._get_provider_model_setting(model_type, model)
|
|
|
|
|
|
|
|
|
|
if model_setting:
|
|
|
|
|
model_setting.enabled = False
|
|
|
|
|
@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:param model: model name
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
return self._get_provider_model_setting(model_type, model)
|
|
|
|
|
|
|
|
|
|
def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
|
|
|
|
|
"""
|
|
|
|
|
Get load balancing config.
|
|
|
|
|
"""
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
db.session.query(ProviderModelSetting)
|
|
|
|
|
db.session.query(LoadBalancingModelConfig)
|
|
|
|
|
.filter(
|
|
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
|
|
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
|
|
|
|
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
LoadBalancingModelConfig.model_name == model,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:param model: model name
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
load_balancing_config_count = (
|
|
|
|
|
db.session.query(LoadBalancingModelConfig)
|
|
|
|
|
.filter(
|
|
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
|
|
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
|
|
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
|
|
|
|
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
LoadBalancingModelConfig.model_name == model,
|
|
|
|
|
)
|
|
|
|
|
@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
if load_balancing_config_count <= 1:
|
|
|
|
|
raise ValueError("Model load balancing configuration must be more than 1.")
|
|
|
|
|
|
|
|
|
|
model_setting = (
|
|
|
|
|
db.session.query(ProviderModelSetting)
|
|
|
|
|
.filter(
|
|
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
model_setting = self._get_provider_model_setting(model_type, model)
|
|
|
|
|
|
|
|
|
|
if model_setting:
|
|
|
|
|
model_setting.load_balancing_enabled = True
|
|
|
|
|
@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:param model: model name
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
model_setting = (
|
|
|
|
|
db.session.query(ProviderModelSetting)
|
|
|
|
|
.filter(
|
|
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelSetting.provider_name.in_(provider_names),
|
|
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
|
|
)
|
|
|
|
|
@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# get preferred provider
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
preferred_model_provider = (
|
|
|
|
|
db.session.query(TenantPreferredModelProvider)
|
|
|
|
|
.filter(
|
|
|
|
|
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
|
|
|
|
TenantPreferredModelProvider.provider_name == self.provider.provider,
|
|
|
|
|
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|