|
|
|
|
@ -1,10 +1,12 @@
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
from collections.abc import Sequence
|
|
|
|
|
from threading import Lock
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
import contexts
|
|
|
|
|
from core.entities import DEFAULT_PLUGIN_ID
|
|
|
|
|
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
|
|
|
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
|
|
|
|
@ -71,13 +73,24 @@ class ModelProviderFactory:
|
|
|
|
|
Get all plugin model providers
|
|
|
|
|
:return: list of plugin model providers
|
|
|
|
|
"""
|
|
|
|
|
# Fetch plugin model providers
|
|
|
|
|
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
|
|
|
|
# check if context is set
|
|
|
|
|
try:
|
|
|
|
|
contexts.plugin_model_providers.get()
|
|
|
|
|
except LookupError:
|
|
|
|
|
contexts.plugin_model_providers.set([])
|
|
|
|
|
contexts.plugin_model_providers_lock.set(Lock())
|
|
|
|
|
|
|
|
|
|
for provider in plugin_providers:
|
|
|
|
|
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
|
|
|
|
with contexts.plugin_model_providers_lock.get():
|
|
|
|
|
plugin_model_providers = contexts.plugin_model_providers.get()
|
|
|
|
|
|
|
|
|
|
# Fetch plugin model providers
|
|
|
|
|
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
|
|
|
|
|
|
|
|
|
for provider in plugin_providers:
|
|
|
|
|
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
|
|
|
|
plugin_model_providers.append(provider)
|
|
|
|
|
|
|
|
|
|
return plugin_providers
|
|
|
|
|
return plugin_model_providers
|
|
|
|
|
|
|
|
|
|
def get_provider_schema(self, provider: str) -> ProviderEntity:
|
|
|
|
|
"""
|
|
|
|
|
|