|
|
|
|
@ -32,6 +32,9 @@ from core.model_runtime.entities.message_entities import (
|
|
|
|
|
UserPromptMessage,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
DEFAULT_V2_ENDPOINT = "maas-api.ml-platform-cn-beijing.volces.com"
|
|
|
|
|
DEFAULT_V3_ENDPOINT = "https://ark.cn-beijing.volces.com/api/v3"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArkClientV3:
|
|
|
|
|
endpoint_id: Optional[str] = None
|
|
|
|
|
@ -43,16 +46,24 @@ class ArkClientV3:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def is_legacy(credentials: dict) -> bool:
|
|
|
|
|
# match default v2 endpoint
|
|
|
|
|
if ArkClientV3.is_compatible_with_legacy(credentials):
|
|
|
|
|
return False
|
|
|
|
|
sdk_version = credentials.get("sdk_version", "v2")
|
|
|
|
|
return sdk_version != "v3"
|
|
|
|
|
# match default v3 endpoint
|
|
|
|
|
if credentials.get("api_endpoint_host") == DEFAULT_V3_ENDPOINT:
|
|
|
|
|
return False
|
|
|
|
|
# only v3 support api_key
|
|
|
|
|
if credentials.get("auth_method") == "api_key":
|
|
|
|
|
return False
|
|
|
|
|
# these cases are considered as sdk v2
|
|
|
|
|
# - modified default v2 endpoint
|
|
|
|
|
# - modified default v3 endpoint and auth without api_key
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def is_compatible_with_legacy(credentials: dict) -> bool:
|
|
|
|
|
sdk_version = credentials.get("sdk_version")
|
|
|
|
|
endpoint = credentials.get("api_endpoint_host")
|
|
|
|
|
return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com"
|
|
|
|
|
return endpoint == DEFAULT_V2_ENDPOINT
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_credentials(cls, credentials):
|
|
|
|
|
@ -64,7 +75,7 @@ class ArkClientV3:
|
|
|
|
|
"sk": credentials['volc_secret_access_key'],
|
|
|
|
|
}
|
|
|
|
|
if cls.is_compatible_with_legacy(credentials):
|
|
|
|
|
args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3"
|
|
|
|
|
args["base_url"] = DEFAULT_V3_ENDPOINT
|
|
|
|
|
|
|
|
|
|
client = ArkClientV3(
|
|
|
|
|
**args
|
|
|
|
|
|