|
|
|
|
@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
|
|
if server_url.endswith('/'):
|
|
|
|
|
server_url = server_url[:-1]
|
|
|
|
|
|
|
|
|
|
client = Client(base_url=server_url)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
handle = client.get_model(model_uid=model_uid)
|
|
|
|
|
except RuntimeError as e:
|
|
|
|
|
raise InvokeAuthorizationError(e)
|
|
|
|
|
|
|
|
|
|
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
|
|
|
|
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
|
|
|
|
embeddings = handle.create_embedding(input=texts)
|
|
|
|
|
except RuntimeError as e:
|
|
|
|
|
raise InvokeServerUnavailableError(e)
|
|
|
|
|
@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
|
|
|
|
|
|
|
if extra_args.max_tokens:
|
|
|
|
|
credentials['max_tokens'] = extra_args.max_tokens
|
|
|
|
|
if server_url.endswith('/'):
|
|
|
|
|
server_url = server_url[:-1]
|
|
|
|
|
|
|
|
|
|
client = Client(base_url=server_url)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
handle = client.get_model(model_uid=model_uid)
|
|
|
|
|
except RuntimeError as e:
|
|
|
|
|
raise InvokeAuthorizationError(e)
|
|
|
|
|
|
|
|
|
|
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
|
|
|
|
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
|
|
|
|
|
|
|
|
|
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
|
|
|
|
except InvokeAuthorizationError as e:
|
|
|
|
|
@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
|
|
parameter_rules=[]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return entity
|
|
|
|
|
return entity
|
|
|
|
|
|