|
|
|
|
@ -2,6 +2,7 @@ import json
|
|
|
|
|
from typing import Type
|
|
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
from langchain.embeddings import XinferenceEmbeddings
|
|
|
|
|
|
|
|
|
|
from core.helper import encrypter
|
|
|
|
|
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
|
|
|
|
@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
|
|
|
|
|
'model_uid': credentials['model_uid'],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llm = XinferenceLLM(
|
|
|
|
|
**credential_kwargs
|
|
|
|
|
)
|
|
|
|
|
if model_type == ModelType.TEXT_GENERATION:
|
|
|
|
|
llm = XinferenceLLM(
|
|
|
|
|
**credential_kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
llm("ping")
|
|
|
|
|
elif model_type == ModelType.EMBEDDINGS:
|
|
|
|
|
embedding = XinferenceEmbeddings(
|
|
|
|
|
**credential_kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
llm("ping")
|
|
|
|
|
embedding.embed_query("ping")
|
|
|
|
|
except Exception as ex:
|
|
|
|
|
raise CredentialsValidateFailedError(str(ex))
|
|
|
|
|
|
|
|
|
|
@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
|
|
|
|
|
:param credentials:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
extra_credentials = cls._get_extra_credentials(credentials)
|
|
|
|
|
credentials.update(extra_credentials)
|
|
|
|
|
if model_type == ModelType.TEXT_GENERATION:
|
|
|
|
|
extra_credentials = cls._get_extra_credentials(credentials)
|
|
|
|
|
credentials.update(extra_credentials)
|
|
|
|
|
|
|
|
|
|
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
|
|
|
|
|
|
|
|
|
|