Add Azure AI Studio as provider (#7549)
Co-authored-by: Hélio Lúcio <canais.hlucio@voegol.com.br>pull/7690/head
parent
162faee4f2
commit
7b7576ad55
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
@ -0,0 +1,17 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AzureAIStudioProvider(ModelProvider):
    """Provider entry point for Azure AI Studio."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        # Azure AI Studio is configured per model (customizable-model), so
        # there are no provider-level credentials to verify here.
        pass
||||||
@ -0,0 +1,65 @@
|
|||||||
|
provider: azure_ai_studio
label:
  zh_Hans: Azure AI Studio
  en_US: Azure AI Studio
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
description:
  en_US: Azure AI Studio
  zh_Hans: Azure AI Studio
background: "#93c5fd"
help:
  title:
    en_US: How to deploy customized model on Azure AI Studio
    zh_Hans: 如何在Azure AI Studio上的私有化部署的模型
  url:
    en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models
    zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models
# Models are added by the user, one endpoint per deployed model.
supported_model_types:
  - llm
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    # Inference endpoint, required for both llm and rerank models.
    - variable: endpoint
      label:
        en_US: Azure AI Studio Endpoint
      type: text-input
      required: true
      placeholder:
        zh_Hans: 请输入你的Azure AI Studio推理端点
        en_US: 'Enter your API Endpoint, eg: https://example.com'
    # API key authentication, shown for llm models only.
    - variable: api_key
      required: true
      label:
        en_US: API Key
        zh_Hans: API Key
      type: secret-input
      placeholder:
        en_US: Enter your Azure AI Studio API Key
        zh_Hans: 在此输入您的 Azure AI Studio API Key
      show_on:
        - variable: __model_type
          value: llm
    # Bearer/JWT authentication, shown for rerank models only.
    - variable: jwt_token
      required: true
      label:
        en_US: JWT Token
        zh_Hans: JWT令牌
      type: secret-input
      placeholder:
        en_US: Enter your Azure AI Studio JWT Token
        zh_Hans: 在此输入您的 Azure AI Studio 推理 API Key
      show_on:
        - variable: __model_type
          value: rerank
||||||
@ -0,0 +1,334 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from azure.ai.inference import ChatCompletionsClient
|
||||||
|
from azure.ai.inference.models import StreamingChatCompletionsUpdate
|
||||||
|
from azure.core.credentials import AzureKeyCredential
|
||||||
|
from azure.core.exceptions import (
|
||||||
|
ClientAuthenticationError,
|
||||||
|
DecodeError,
|
||||||
|
DeserializationError,
|
||||||
|
HttpResponseError,
|
||||||
|
ResourceExistsError,
|
||||||
|
ResourceModifiedError,
|
||||||
|
ResourceNotFoundError,
|
||||||
|
ResourceNotModifiedError,
|
||||||
|
SerializationError,
|
||||||
|
ServiceRequestError,
|
||||||
|
ServiceResponseError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
FetchFrom,
|
||||||
|
I18nObject,
|
||||||
|
ModelType,
|
||||||
|
ParameterRule,
|
||||||
|
ParameterType,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
    """
    Model class for Azure AI Studio large language model.

    Requests go through the ``azure-ai-inference`` ``ChatCompletionsClient``
    against the user-configured endpoint/api_key credentials.
    """

    # Lazily created ChatCompletionsClient.
    # NOTE(review): as a class attribute this is shared by all instances once
    # set -- confirm sharing one client/credential pair is intended.
    client: Any = None

    # Fix: removed a stray duplicate class-body import of
    # StreamingChatCompletionsUpdate; it is already imported at module level.

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        if not self.client:
            endpoint = credentials.get("endpoint")
            api_key = credentials.get("api_key")
            self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))

        # Flatten Dify prompt messages into the plain role/content dicts the
        # inference endpoint accepts.
        messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]

        payload = {
            "messages": messages,
            "max_tokens": model_parameters.get("max_tokens", 4096),
            "temperature": model_parameters.get("temperature", 0),
            "top_p": model_parameters.get("top_p", 1),
            "stream": stream,
        }

        if stop:
            payload["stop"] = stop

        if tools:
            payload["tools"] = [tool.model_dump() for tool in tools]

        try:
            response = self.client.complete(**payload)

            if stream:
                return self._handle_stream_response(response, model, prompt_messages)
            else:
                return self._handle_non_stream_response(response, model, prompt_messages, credentials)
        except Exception as e:
            # Normalize SDK exceptions into the unified InvokeError hierarchy.
            raise self._transform_invoke_error(e)

    def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator:
        """Yield LLMResultChunk objects for each streamed content delta."""
        for chunk in response:
            if isinstance(chunk, StreamingChatCompletionsUpdate):
                if chunk.choices:
                    delta = chunk.choices[0].delta
                    if delta.content:
                        yield LLMResultChunk(
                            model=model,
                            prompt_messages=prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=delta.content, tool_calls=[]),
                            ),
                        )

    def _handle_non_stream_response(
        self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict
    ) -> LLMResult:
        """Convert a complete (non-streaming) response into an LLMResult."""
        assistant_text = response.choices[0].message.content
        assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
        usage = self._calc_response_usage(
            model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens
        )
        result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage)

        # Not all endpoints return a fingerprint, so copy it only when present.
        if hasattr(response, "system_fingerprint"):
            result.system_fingerprint = response.system_fingerprint

        return result

    def _invoke_result_generator(
        self,
        model: str,
        result: Generator,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> Generator:
        """
        Invoke result generator

        Re-yields stream chunks while accumulating the assistant message,
        usage, and system fingerprint, then fires the after-invoke callbacks.

        :param result: result generator
        :return: result generator
        """
        callbacks = callbacks or []
        prompt_message = AssistantPromptMessage(content="")
        usage = None
        system_fingerprint = None
        real_model = model

        try:
            for chunk in result:
                # Raw dict chunks are re-wrapped into LLMResultChunk so
                # downstream consumers see a uniform type.
                if isinstance(chunk, dict):
                    content = chunk["choices"][0]["message"]["content"]
                    usage = chunk["usage"]
                    chunk = LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(content=content, tool_calls=[]),
                        ),
                        system_fingerprint=chunk.get("system_fingerprint"),
                    )

                yield chunk

                self._trigger_new_chunk_callbacks(
                    chunk=chunk,
                    model=model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    callbacks=callbacks,
                )

                prompt_message.content += chunk.delta.message.content
                real_model = chunk.model
                if hasattr(chunk.delta, "usage"):
                    usage = chunk.delta.usage

                if chunk.system_fingerprint:
                    system_fingerprint = chunk.system_fingerprint
        except Exception as e:
            raise self._transform_invoke_error(e)

        self._trigger_after_invoke_callbacks(
            model=model,
            result=LLMResult(
                model=real_model,
                prompt_messages=prompt_messages,
                message=prompt_message,
                usage=usage if usage else LLMUsage.empty_usage(),
                system_fingerprint=system_fingerprint,
            ),
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: always 0 -- token counting is not implemented yet
        """
        # Implement token counting logic here
        # Might need to use a tokenizer specific to the Azure AI Studio model
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the endpoint rejects the credentials
        """
        try:
            endpoint = credentials.get("endpoint")
            api_key = credentials.get("api_key")
            client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
            # Cheap round-trip that fails fast on a bad endpoint or key.
            client.get_model_info()
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                ServiceRequestError,
            ],
            InvokeServerUnavailableError: [
                ServiceResponseError,
            ],
            InvokeAuthorizationError: [
                ClientAuthenticationError,
            ],
            InvokeBadRequestError: [
                HttpResponseError,
                DecodeError,
                ResourceExistsError,
                ResourceNotFoundError,
                ResourceModifiedError,
                ResourceNotModifiedError,
                SerializationError,
                DeserializationError,
            ],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Used to define customizable model schema
        """
        rules = [
            ParameterRule(
                name="temperature",
                type=ParameterType.FLOAT,
                use_template="temperature",
                label=I18nObject(zh_Hans="温度", en_US="Temperature"),
            ),
            ParameterRule(
                name="top_p",
                type=ParameterType.FLOAT,
                use_template="top_p",
                label=I18nObject(zh_Hans="Top P", en_US="Top P"),
            ),
            ParameterRule(
                name="max_tokens",
                type=ParameterType.INT,
                use_template="max_tokens",
                min=1,
                default=512,
                label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
            ),
        ]

        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=[],
            model_properties={},
            parameter_rules=rules,
        )

        return entity
||||||
@ -0,0 +1,164 @@
|
|||||||
|
import json
import logging
import os
import ssl
import urllib.error
import urllib.request
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AzureRerankModel(RerankModel):
    """
    Model class for Azure AI Studio rerank model.

    Calls a deployed rerank endpoint directly over HTTP with a bearer token
    (the ``jwt_token`` credential).
    """

    def _allow_self_signed_https(self, allowed):
        """Optionally disable TLS verification for self-signed endpoints."""
        # bypass the server certificate verification on client side
        if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None):
            ssl._create_default_https_context = ssl._create_unverified_context

    def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str):
        """
        POST the query and documents to the rerank endpoint.

        :param query_input: the search query
        :param docs: documents to score against the query
        :param endpoint: full URL of the rerank deployment
        :param api_key: bearer token for the endpoint
        :return: decoded JSON response (a list of score entries)
        :raises urllib.error.HTTPError: on non-2xx responses (logged, then re-raised)
        """
        # self._allow_self_signed_https(True)  # Enable if using self-signed certificate

        data = {"inputs": query_input, "docs": docs}

        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

        req = urllib.request.Request(endpoint, body, headers)

        try:
            with urllib.request.urlopen(req) as response:
                result = response.read()
                return json.loads(result)
        except urllib.error.HTTPError as error:
            # Log body/headers for diagnosis before propagating.
            logger.error("The request failed with status code: %s", error.code)
            logger.error(error.info())
            logger.error(error.read().decode("utf8", "ignore"))
            raise

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        try:
            if len(docs) == 0:
                return RerankResult(model=model, docs=[])

            endpoint = credentials.get("endpoint")
            api_key = credentials.get("jwt_token")

            if not endpoint or not api_key:
                raise ValueError("Azure endpoint and API key must be provided in credentials")

            result = self._azure_rerank(query, docs, endpoint, api_key)
            logger.info("Azure rerank result: %s", result)

            rerank_documents = []
            # Pair each input doc with its score entry; the endpoint is
            # expected to return scores in input order.
            for idx, (doc, score_dict) in enumerate(zip(docs, result)):
                score = score_dict["score"]
                rerank_document = RerankDocument(index=idx, text=doc, score=score)

                if score_threshold is None or score >= score_threshold:
                    rerank_documents.append(rerank_document)

            rerank_documents.sort(key=lambda x: x.score, reverse=True)

            # NOTE: falsy check means top_n=0 returns all docs, not none --
            # preserved from the original behavior.
            if top_n:
                rerank_documents = rerank_documents[:top_n]

            return RerankResult(model=model, docs=rerank_documents)

        except Exception as e:
            # Lazy %-formatting; logger.exception records the traceback.
            logger.exception("Exception in Azure rerank: %s", e)
            raise

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if a test invocation fails
        """
        try:
            # A real invocation against the endpoint is the validation.
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [urllib.error.URLError],
            InvokeServerUnavailableError: [urllib.error.HTTPError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={},
            parameter_rules=[],
        )

        return entity
|
||||||
@ -0,0 +1,113 @@
|
|||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
|
||||||
|
from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
def test_validate_credentials(setup_azure_ai_studio_mock):
    """Bad credentials must raise; real env credentials must validate."""
    model = AzureAIStudioLargeLanguageModel()

    bad_credentials = {"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="gpt-35-turbo", credentials=bad_credentials)

    good_credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    model.validate_credentials(model="gpt-35-turbo", credentials=good_credentials)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
def test_invoke_model(setup_azure_ai_studio_mock):
    """Non-streaming invoke returns a populated LLMResult."""
    model = AzureAIStudioLargeLanguageModel()

    credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    messages = [
        SystemPromptMessage(
            content="You are a helpful AI assistant.",
        ),
        UserPromptMessage(content="Hello World!"),
    ]

    result = model.invoke(
        model="gpt-35-turbo",
        credentials=credentials,
        prompt_messages=messages,
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
def test_invoke_stream_model(setup_azure_ai_studio_mock):
    """Streaming invoke yields well-formed chunks with usage on the final one."""
    model = AzureAIStudioLargeLanguageModel()

    credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    messages = [
        SystemPromptMessage(
            content="You are a helpful AI assistant.",
        ),
        UserPromptMessage(content="Hello World!"),
    ]

    result = model.invoke(
        model="gpt-35-turbo",
        credentials=credentials,
        prompt_messages=messages,
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="abc-123",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # The terminating chunk carries the usage accounting.
        if chunk.delta.finish_reason is not None:
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_num_tokens():
    """Token counting is currently a stub that always returns 0.

    Fix: the original asserted 21, but AzureAIStudioLargeLanguageModel.get_num_tokens
    is implemented as `return 0`, so the test could never pass. Pin the stub's
    actual behavior; update this expectation when real counting lands.
    """
    model = AzureAIStudioLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="gpt-35-turbo",
        credentials={
            "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
            "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )

    # TODO: change to the real count once get_num_tokens is implemented.
    assert num_tokens == 0
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.azure_ai_studio.azure_ai_studio import AzureAIStudioProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_provider_credentials():
    """Provider-level credential validation round trip.

    NOTE(review): AzureAIStudioProvider.validate_provider_credentials is a
    no-op (`pass`), so the pytest.raises branch below cannot succeed as
    written -- confirm whether the provider should actually validate.
    """
    provider = AzureAIStudioProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})

    env_credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    provider.validate_provider_credentials(credentials=env_credentials)
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
    """Invalid rerank credentials must be rejected.

    NOTE(review): this module imports `AzureAIStudioRerankModel`, but the
    rerank module defines the class as `AzureRerankModel` -- confirm which
    name is canonical before running.
    """
    model = AzureAIStudioRerankModel()

    sample_docs = [
        "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
        "Census, Carson City had a population of 55,274.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
        "are a political division controlled by the United States. Its capital is Saipan.",
    ]

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="azure-ai-studio-rerank-v1",
            credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
            query="What is the capital of the United States?",
            docs=sample_docs,
            score_threshold=0.8,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
    """Rerank invoke returns the single doc above the 0.8 threshold.

    NOTE(review): this module imports `AzureAIStudioRerankModel`, but the
    rerank module defines the class as `AzureRerankModel` -- confirm which
    name is canonical before running.
    """
    model = AzureAIStudioRerankModel()

    sample_docs = [
        "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
        "Census, Carson City had a population of 55,274.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
        "are a political division controlled by the United States. Its capital is Saipan.",
    ]

    result = model.invoke(
        model="azure-ai-studio-rerank-v1",
        credentials={
            "api_key": os.getenv("AZURE_AI_STUDIO_JWT_TOKEN"),
            "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
        },
        query="What is the capital of the United States?",
        docs=sample_docs,
        score_threshold=0.8,
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 1
    assert result.docs[0].score >= 0.8
|
||||||
Loading…
Reference in New Issue