|
|
|
|
@ -5,11 +5,15 @@ from langchain.callbacks import CallbackManager
|
|
|
|
|
from langchain.llms.fake import FakeListLLM
|
|
|
|
|
|
|
|
|
|
from core.constant import llm_constant
|
|
|
|
|
from core.llm.error import ProviderTokenNotInitError
|
|
|
|
|
from core.llm.provider.base import BaseProvider
|
|
|
|
|
from core.llm.provider.llm_provider_service import LLMProviderService
|
|
|
|
|
from core.llm.provider.openai_provider import OpenAIProvider
|
|
|
|
|
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
|
|
|
|
|
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
|
|
|
|
|
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
|
|
|
|
from core.llm.streamable_open_ai import StreamableOpenAI
|
|
|
|
|
from models.provider import ProviderType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LLMBuilder:
|
|
|
|
|
@ -34,7 +38,7 @@ class LLMBuilder:
|
|
|
|
|
if model_name == 'fake':
|
|
|
|
|
return FakeListLLM(responses=[])
|
|
|
|
|
|
|
|
|
|
provider = current_app.config.get('DEFAULT_LLM_PROVIDER')
|
|
|
|
|
provider = cls.get_default_provider(tenant_id)
|
|
|
|
|
|
|
|
|
|
mode = cls.get_mode_by_model(model_name)
|
|
|
|
|
if mode == 'chat':
|
|
|
|
|
@ -50,7 +54,7 @@ class LLMBuilder:
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"model name {model_name} is not supported.")
|
|
|
|
|
|
|
|
|
|
model_credentials = cls.get_model_credentials(tenant_id, model_name)
|
|
|
|
|
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
|
|
|
|
|
|
|
|
|
return llm_cls(
|
|
|
|
|
model_name=model_name,
|
|
|
|
|
@ -96,7 +100,7 @@ class LLMBuilder:
|
|
|
|
|
raise ValueError(f"model name {model_name} is not supported.")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
|
|
|
|
|
def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
|
|
|
|
|
"""
|
|
|
|
|
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
|
|
|
|
|
Raises an exception if the model_name is not found or if the provider is not found.
|
|
|
|
|
@ -108,7 +112,19 @@ class LLMBuilder:
|
|
|
|
|
# raise Exception('model {} not found'.format(model_name))
|
|
|
|
|
|
|
|
|
|
# model_provider = llm_constant.models[model_name]
|
|
|
|
|
model_provider = current_app.config.get('DEFAULT_LLM_PROVIDER')
|
|
|
|
|
|
|
|
|
|
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
|
|
|
|
|
return provider_service.get_credentials(model_name)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_default_provider(cls, tenant_id: str) -> str:
|
|
|
|
|
provider = BaseProvider.get_valid_provider(tenant_id)
|
|
|
|
|
if not provider:
|
|
|
|
|
raise ProviderTokenNotInitError()
|
|
|
|
|
|
|
|
|
|
if provider.provider_type == ProviderType.SYSTEM.value:
|
|
|
|
|
provider_name = 'openai'
|
|
|
|
|
else:
|
|
|
|
|
provider_name = provider.provider_name
|
|
|
|
|
|
|
|
|
|
return provider_name
|
|
|
|
|
|