From 7c1d842cfe089ac6d9275c896d03a49d3c64af36 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Wed, 12 Feb 2025 14:21:58 +0800 Subject: [PATCH] (1.0) fix: invalid default model provider (#13572) --- .../model_providers/model_provider_factory.py | 6 +++++- api/core/plugin/entities/plugin.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 23596558db..b311f069a8 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -20,6 +20,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator +from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.manager.asset import PluginAssetManager from core.plugin.manager.model import PluginModelManager @@ -112,6 +113,9 @@ class ModelProviderFactory: :param provider: provider name :return: provider schema """ + if "/" not in provider: + provider = str(ModelProviderID(provider)) + # fetch plugin model providers plugin_model_provider_entities = self.get_plugin_model_providers() @@ -363,4 +367,4 @@ class ModelProviderFactory: plugin_id = "/".join(provider.split("/")[:-1]) provider_name = provider.split("/")[-1] - return plugin_id, provider_name + return str(plugin_id), provider_name diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index ee65e86826..aa78eb919c 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -169,6 +169,13 @@ class GenericProviderID: return f"{self.organization}/{self.plugin_name}" +class ModelProviderID(GenericProviderID): + def __init__(self, value: str, is_hardcoded: bool = False) -> None: + super().__init__(value, is_hardcoded) + if self.organization == "langgenius" and self.provider_name == "google": + self.provider_name = "gemini" + + class PluginDependency(BaseModel): class Type(enum.StrEnum): Github = PluginInstallationSource.Github.value