feat: support more model types and builtin tools on aws/sagemaker (#8061)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>pull/8124/head
parent
ab7d79275e
commit
954580a4af
@ -0,0 +1,142 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import IO, Any, Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||||
|
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class SageMakerSpeech2TextModel(Speech2TextModel):
|
||||||
|
"""
|
||||||
|
Model class for Xinference speech to text model.
|
||||||
|
"""
|
||||||
|
sagemaker_client: Any = None
|
||||||
|
s3_client : Any = None
|
||||||
|
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
file: IO[bytes], user: Optional[str] = None) \
|
||||||
|
-> str:
|
||||||
|
"""
|
||||||
|
Invoke speech2text model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param file: audio file
|
||||||
|
:param user: unique user id
|
||||||
|
:return: text for given audio file
|
||||||
|
"""
|
||||||
|
asr_text = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.sagemaker_client:
|
||||||
|
access_key = credentials.get('aws_access_key_id')
|
||||||
|
secret_key = credentials.get('aws_secret_access_key')
|
||||||
|
aws_region = credentials.get('aws_region')
|
||||||
|
if aws_region:
|
||||||
|
if access_key and secret_key:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
self.s3_client = boto3.client("s3",
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||||
|
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||||
|
self.s3_client = boto3.client("s3")
|
||||||
|
|
||||||
|
s3_prefix='dify/speech2text/'
|
||||||
|
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||||
|
bucket = credentials.get('audio_s3_cache_bucket')
|
||||||
|
|
||||||
|
s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
|
||||||
|
payload = {
|
||||||
|
"audio_s3_presign_uri" : s3_presign_url
|
||||||
|
}
|
||||||
|
|
||||||
|
response_model = self.sagemaker_client.invoke_endpoint(
|
||||||
|
EndpointName=sagemaker_endpoint,
|
||||||
|
Body=json.dumps(payload),
|
||||||
|
ContentType="application/json"
|
||||||
|
)
|
||||||
|
json_str = response_model['Body'].read().decode('utf8')
|
||||||
|
json_obj = json.loads(json_str)
|
||||||
|
asr_text = json_obj['text']
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f'Exception {e}, line : {line}')
|
||||||
|
|
||||||
|
return asr_text
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
The key is the error type thrown to the caller
|
||||||
|
The value is the error type thrown by the model,
|
||||||
|
which needs to be converted into a unified error type for the caller.
|
||||||
|
|
||||||
|
:return: Invoke error mapping
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [
|
||||||
|
InvokeConnectionError
|
||||||
|
],
|
||||||
|
InvokeServerUnavailableError: [
|
||||||
|
InvokeServerUnavailableError
|
||||||
|
],
|
||||||
|
InvokeRateLimitError: [
|
||||||
|
InvokeRateLimitError
|
||||||
|
],
|
||||||
|
InvokeAuthorizationError: [
|
||||||
|
InvokeAuthorizationError
|
||||||
|
],
|
||||||
|
InvokeBadRequestError: [
|
||||||
|
InvokeBadRequestError,
|
||||||
|
KeyError,
|
||||||
|
ValueError
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
|
"""
|
||||||
|
used to define customizable model schema
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(
|
||||||
|
en_US=model
|
||||||
|
),
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_type=ModelType.SPEECH2TEXT,
|
||||||
|
model_properties={ },
|
||||||
|
parameter_rules=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
||||||
@ -0,0 +1,287 @@
|
|||||||
|
import concurrent.futures
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class TTSModelType(Enum):
|
||||||
|
PresetVoice = "PresetVoice"
|
||||||
|
CloneVoice = "CloneVoice"
|
||||||
|
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
|
||||||
|
InstructVoice = "InstructVoice"
|
||||||
|
|
||||||
|
class SageMakerText2SpeechModel(TTSModel):
|
||||||
|
|
||||||
|
sagemaker_client: Any = None
|
||||||
|
s3_client : Any = None
|
||||||
|
comprehend_client : Any = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# preset voices, need support custom voice
|
||||||
|
self.model_voices = {
|
||||||
|
'__default': {
|
||||||
|
'all': [
|
||||||
|
{'name': 'Default', 'value': 'default'},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
'CosyVoice': {
|
||||||
|
'zh-Hans': [
|
||||||
|
{'name': '中文男', 'value': '中文男'},
|
||||||
|
{'name': '中文女', 'value': '中文女'},
|
||||||
|
{'name': '粤语女', 'value': '粤语女'},
|
||||||
|
],
|
||||||
|
'zh-Hant': [
|
||||||
|
{'name': '中文男', 'value': '中文男'},
|
||||||
|
{'name': '中文女', 'value': '中文女'},
|
||||||
|
{'name': '粤语女', 'value': '粤语女'},
|
||||||
|
],
|
||||||
|
'en-US': [
|
||||||
|
{'name': '英文男', 'value': '英文男'},
|
||||||
|
{'name': '英文女', 'value': '英文女'},
|
||||||
|
],
|
||||||
|
'ja-JP': [
|
||||||
|
{'name': '日语男', 'value': '日语男'},
|
||||||
|
],
|
||||||
|
'ko-KR': [
|
||||||
|
{'name': '韩语女', 'value': '韩语女'},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _detect_lang_code(self, content:str, map_dict:dict=None):
|
||||||
|
map_dict = {
|
||||||
|
"zh" : "<|zh|>",
|
||||||
|
"en" : "<|en|>",
|
||||||
|
"ja" : "<|jp|>",
|
||||||
|
"zh-TW" : "<|yue|>",
|
||||||
|
"ko" : "<|ko|>"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||||
|
language_code = response['Languages'][0]['LanguageCode']
|
||||||
|
|
||||||
|
return map_dict.get(language_code, '<|zh|>')
|
||||||
|
|
||||||
|
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
|
||||||
|
if model_type == TTSModelType.PresetVoice.value and model_role:
|
||||||
|
return { "tts_text" : content_text, "role" : model_role }
|
||||||
|
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
|
||||||
|
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
|
||||||
|
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||||
|
lang_tag = self._detect_lang_code(content_text)
|
||||||
|
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
|
||||||
|
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||||
|
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
|
||||||
|
|
||||||
|
raise RuntimeError(f"Invalid params for {model_type}")
|
||||||
|
|
||||||
|
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||||
|
user: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
_invoke text2speech model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param tenant_id: user tenant id
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param voice: model timbre
|
||||||
|
:param content_text: text content to be translated
|
||||||
|
:param user: unique user id
|
||||||
|
:return: text translated to audio file
|
||||||
|
"""
|
||||||
|
if not self.sagemaker_client:
|
||||||
|
access_key = credentials.get('aws_access_key_id')
|
||||||
|
secret_key = credentials.get('aws_secret_access_key')
|
||||||
|
aws_region = credentials.get('aws_region')
|
||||||
|
if aws_region:
|
||||||
|
if access_key and secret_key:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
self.s3_client = boto3.client("s3",
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
self.comprehend_client = boto3.client('comprehend',
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||||
|
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||||
|
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||||
|
self.s3_client = boto3.client("s3")
|
||||||
|
self.comprehend_client = boto3.client('comprehend')
|
||||||
|
|
||||||
|
model_type = credentials.get('audio_model_type', 'PresetVoice')
|
||||||
|
prompt_text = credentials.get('prompt_text')
|
||||||
|
prompt_audio = credentials.get('prompt_audio')
|
||||||
|
instruct_text = credentials.get('instruct_text')
|
||||||
|
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||||
|
payload = self._build_tts_payload(
|
||||||
|
model_type,
|
||||||
|
content_text,
|
||||||
|
voice,
|
||||||
|
prompt_text,
|
||||||
|
prompt_audio,
|
||||||
|
instruct_text
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
|
"""
|
||||||
|
used to define customizable model schema
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(
|
||||||
|
en_US=model
|
||||||
|
),
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_type=ModelType.TTS,
|
||||||
|
model_properties={},
|
||||||
|
parameter_rules=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
The key is the error type thrown to the caller
|
||||||
|
The value is the error type thrown by the model,
|
||||||
|
which needs to be converted into a unified error type for the caller.
|
||||||
|
|
||||||
|
:return: Invoke error mapping
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [
|
||||||
|
InvokeConnectionError
|
||||||
|
],
|
||||||
|
InvokeServerUnavailableError: [
|
||||||
|
InvokeServerUnavailableError
|
||||||
|
],
|
||||||
|
InvokeRateLimitError: [
|
||||||
|
InvokeRateLimitError
|
||||||
|
],
|
||||||
|
InvokeAuthorizationError: [
|
||||||
|
InvokeAuthorizationError
|
||||||
|
],
|
||||||
|
InvokeBadRequestError: [
|
||||||
|
InvokeBadRequestError,
|
||||||
|
KeyError,
|
||||||
|
ValueError
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
||||||
|
return 15
|
||||||
|
|
||||||
|
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
|
||||||
|
return "mp3"
|
||||||
|
|
||||||
|
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||||
|
return 5
|
||||||
|
|
||||||
|
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||||
|
audio_model_name = 'CosyVoice'
|
||||||
|
for key, voices in self.model_voices.items():
|
||||||
|
if key in audio_model_name:
|
||||||
|
if language and language in voices:
|
||||||
|
return voices[language]
|
||||||
|
elif 'all' in voices:
|
||||||
|
return voices['all']
|
||||||
|
|
||||||
|
return self.model_voices['__default']['all']
|
||||||
|
|
||||||
|
def _invoke_sagemaker(self, payload:dict, endpoint:str):
|
||||||
|
response_model = self.sagemaker_client.invoke_endpoint(
|
||||||
|
EndpointName=endpoint,
|
||||||
|
Body=json.dumps(payload),
|
||||||
|
ContentType="application/json",
|
||||||
|
)
|
||||||
|
json_str = response_model['Body'].read().decode('utf8')
|
||||||
|
json_obj = json.loads(json_str)
|
||||||
|
return json_obj
|
||||||
|
|
||||||
|
def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> any:
|
||||||
|
"""
|
||||||
|
_tts_invoke_streaming text2speech model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param content_text: text content to be translated
|
||||||
|
:param voice: model timbre
|
||||||
|
:return: text translated to audio file
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
lang_tag = ''
|
||||||
|
if model_type == TTSModelType.CloneVoice_CrossLingual.value:
|
||||||
|
lang_tag = payload.pop('lang_tag')
|
||||||
|
|
||||||
|
word_limit = self._get_model_word_limit(model='', credentials={})
|
||||||
|
content_text = payload.get("tts_text")
|
||||||
|
if len(content_text) > word_limit:
|
||||||
|
split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
|
||||||
|
sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ]
|
||||||
|
len_sent = len(sentences)
|
||||||
|
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent))
|
||||||
|
payloads = [ copy.deepcopy(payload) for i in range(len_sent) ]
|
||||||
|
for idx in range(len_sent):
|
||||||
|
payloads[idx]["tts_text"] = sentences[idx]
|
||||||
|
|
||||||
|
futures = [ executor.submit(
|
||||||
|
self._invoke_sagemaker,
|
||||||
|
payload=payload,
|
||||||
|
endpoint=sagemaker_endpoint,
|
||||||
|
)
|
||||||
|
for payload in payloads]
|
||||||
|
|
||||||
|
for index, future in enumerate(futures):
|
||||||
|
resp = future.result()
|
||||||
|
audio_bytes = requests.get(resp.get('s3_presign_url')).content
|
||||||
|
for i in range(0, len(audio_bytes), 1024):
|
||||||
|
yield audio_bytes[i:i + 1024]
|
||||||
|
else:
|
||||||
|
resp = self._invoke_sagemaker(payload, sagemaker_endpoint)
|
||||||
|
audio_bytes = requests.get(resp.get('s3_presign_url')).content
|
||||||
|
|
||||||
|
for i in range(0, len(audio_bytes), 1024):
|
||||||
|
yield audio_bytes[i:i + 1024]
|
||||||
|
except Exception as ex:
|
||||||
|
raise InvokeBadRequestError(str(ex))
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaYamlToJsonTool(BuiltinTool):
|
||||||
|
lambda_client: Any = None
|
||||||
|
|
||||||
|
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
|
||||||
|
msg = {
|
||||||
|
"body": yaml_content
|
||||||
|
}
|
||||||
|
logger.info(json.dumps(msg))
|
||||||
|
|
||||||
|
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
|
||||||
|
InvocationType='RequestResponse',
|
||||||
|
Payload=json.dumps(msg))
|
||||||
|
response_body = invoke_response['Payload']
|
||||||
|
|
||||||
|
response_str = response_body.read().decode("utf-8")
|
||||||
|
resp_json = json.loads(response_str)
|
||||||
|
|
||||||
|
logger.info(resp_json)
|
||||||
|
if resp_json['statusCode'] != 200:
|
||||||
|
raise Exception(f"Invalid status code: {response_str}")
|
||||||
|
|
||||||
|
return resp_json['body']
|
||||||
|
|
||||||
|
def _invoke(self,
|
||||||
|
user_id: str,
|
||||||
|
tool_parameters: dict[str, Any],
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
invoke tools
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.lambda_client:
|
||||||
|
aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region
|
||||||
|
if aws_region:
|
||||||
|
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.lambda_client = boto3.client("lambda")
|
||||||
|
|
||||||
|
yaml_content = tool_parameters.get('yaml_content', '')
|
||||||
|
if not yaml_content:
|
||||||
|
return self.create_text_message('Please input yaml_content')
|
||||||
|
|
||||||
|
lambda_name = tool_parameters.get('lambda_name', '')
|
||||||
|
if not lambda_name:
|
||||||
|
return self.create_text_message('Please input lambda_name')
|
||||||
|
logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')
|
||||||
|
|
||||||
|
result = self._invoke_lambda(lambda_name, yaml_content)
|
||||||
|
logger.debug(result)
|
||||||
|
|
||||||
|
return self.create_text_message(result)
|
||||||
|
except Exception as e:
|
||||||
|
return self.create_text_message(f'Exception: {str(e)}')
|
||||||
|
|
||||||
|
console_handler.flush()
|
||||||
@ -0,0 +1,95 @@
|
|||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class TTSModelType(Enum):
|
||||||
|
PresetVoice = "PresetVoice"
|
||||||
|
CloneVoice = "CloneVoice"
|
||||||
|
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
|
||||||
|
InstructVoice = "InstructVoice"
|
||||||
|
|
||||||
|
class SageMakerTTSTool(BuiltinTool):
|
||||||
|
sagemaker_client: Any = None
|
||||||
|
sagemaker_endpoint:str = None
|
||||||
|
s3_client : Any = None
|
||||||
|
comprehend_client : Any = None
|
||||||
|
|
||||||
|
def _detect_lang_code(self, content:str, map_dict:dict=None):
|
||||||
|
map_dict = {
|
||||||
|
"zh" : "<|zh|>",
|
||||||
|
"en" : "<|en|>",
|
||||||
|
"ja" : "<|jp|>",
|
||||||
|
"zh-TW" : "<|yue|>",
|
||||||
|
"ko" : "<|ko|>"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||||
|
language_code = response['Languages'][0]['LanguageCode']
|
||||||
|
return map_dict.get(language_code, '<|zh|>')
|
||||||
|
|
||||||
|
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
|
||||||
|
if model_type == TTSModelType.PresetVoice.value and model_role:
|
||||||
|
return { "tts_text" : content_text, "role" : model_role }
|
||||||
|
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
|
||||||
|
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
|
||||||
|
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||||
|
lang_tag = self._detect_lang_code(content_text)
|
||||||
|
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
|
||||||
|
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||||
|
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
|
||||||
|
|
||||||
|
raise RuntimeError(f"Invalid params for {model_type}")
|
||||||
|
|
||||||
|
def _invoke_sagemaker(self, payload:dict, endpoint:str):
|
||||||
|
response_model = self.sagemaker_client.invoke_endpoint(
|
||||||
|
EndpointName=endpoint,
|
||||||
|
Body=json.dumps(payload),
|
||||||
|
ContentType="application/json",
|
||||||
|
)
|
||||||
|
json_str = response_model['Body'].read().decode('utf8')
|
||||||
|
json_obj = json.loads(json_str)
|
||||||
|
return json_obj
|
||||||
|
|
||||||
|
def _invoke(self,
|
||||||
|
user_id: str,
|
||||||
|
tool_parameters: dict[str, Any],
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
invoke tools
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.sagemaker_client:
|
||||||
|
aws_region = tool_parameters.get('aws_region')
|
||||||
|
if aws_region:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||||
|
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||||
|
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||||
|
self.s3_client = boto3.client("s3")
|
||||||
|
self.comprehend_client = boto3.client('comprehend')
|
||||||
|
|
||||||
|
if not self.sagemaker_endpoint:
|
||||||
|
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
|
||||||
|
|
||||||
|
tts_text = tool_parameters.get('tts_text')
|
||||||
|
tts_infer_type = tool_parameters.get('tts_infer_type')
|
||||||
|
|
||||||
|
voice = tool_parameters.get('voice')
|
||||||
|
mock_voice_audio = tool_parameters.get('mock_voice_audio')
|
||||||
|
mock_voice_text = tool_parameters.get('mock_voice_text')
|
||||||
|
voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
|
||||||
|
payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)
|
||||||
|
|
||||||
|
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
|
||||||
|
|
||||||
|
return self.create_text_message(text=result['s3_presign_url'])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return self.create_text_message(f'Exception {str(e)}')
|
||||||
@ -0,0 +1,149 @@
|
|||||||
|
identity:
|
||||||
|
name: sagemaker_tts
|
||||||
|
author: AWS
|
||||||
|
label:
|
||||||
|
en_US: SagemakerTTS
|
||||||
|
zh_Hans: Sagemaker语音合成
|
||||||
|
pt_BR: SagemakerTTS
|
||||||
|
icon: icon.svg
|
||||||
|
description:
|
||||||
|
human:
|
||||||
|
en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
|
||||||
|
zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
|
||||||
|
pt_BR: A tool for Speech synthesis.
|
||||||
|
llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||||
|
parameters:
|
||||||
|
- name: sagemaker_endpoint
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: sagemaker endpoint for tts
|
||||||
|
zh_Hans: 语音生成的SageMaker端点
|
||||||
|
pt_BR: sagemaker endpoint for tts
|
||||||
|
human_description:
|
||||||
|
en_US: sagemaker endpoint for tts
|
||||||
|
zh_Hans: 语音生成的SageMaker端点
|
||||||
|
pt_BR: sagemaker endpoint for tts
|
||||||
|
llm_description: sagemaker endpoint for tts
|
||||||
|
form: form
|
||||||
|
- name: tts_text
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: tts text
|
||||||
|
zh_Hans: 语音合成原文
|
||||||
|
pt_BR: tts text
|
||||||
|
human_description:
|
||||||
|
en_US: tts text
|
||||||
|
zh_Hans: 语音合成原文
|
||||||
|
pt_BR: tts text
|
||||||
|
llm_description: tts text
|
||||||
|
form: llm
|
||||||
|
- name: tts_infer_type
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: tts infer type
|
||||||
|
zh_Hans: 合成方式
|
||||||
|
pt_BR: tts infer type
|
||||||
|
human_description:
|
||||||
|
en_US: tts infer type
|
||||||
|
zh_Hans: 合成方式
|
||||||
|
pt_BR: tts infer type
|
||||||
|
llm_description: tts infer type
|
||||||
|
options:
|
||||||
|
- value: PresetVoice
|
||||||
|
label:
|
||||||
|
en_US: preset voice
|
||||||
|
zh_Hans: 预置音色
|
||||||
|
- value: CloneVoice
|
||||||
|
label:
|
||||||
|
en_US: clone voice
|
||||||
|
zh_Hans: 克隆音色
|
||||||
|
- value: CloneVoice_CrossLingual
|
||||||
|
label:
|
||||||
|
en_US: clone crossLingual voice
|
||||||
|
zh_Hans: 克隆音色(跨语言)
|
||||||
|
- value: InstructVoice
|
||||||
|
label:
|
||||||
|
en_US: instruct voice
|
||||||
|
zh_Hans: 指令音色
|
||||||
|
form: form
|
||||||
|
- name: voice
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: preset voice
|
||||||
|
zh_Hans: 预置音色
|
||||||
|
pt_BR: preset voice
|
||||||
|
human_description:
|
||||||
|
en_US: preset voice
|
||||||
|
zh_Hans: 预置音色
|
||||||
|
pt_BR: preset voice
|
||||||
|
llm_description: preset voice
|
||||||
|
options:
|
||||||
|
- value: 中文男
|
||||||
|
label:
|
||||||
|
en_US: zh-cn male
|
||||||
|
zh_Hans: 中文男
|
||||||
|
- value: 中文女
|
||||||
|
label:
|
||||||
|
en_US: zh-cn female
|
||||||
|
zh_Hans: 中文女
|
||||||
|
- value: 粤语女
|
||||||
|
label:
|
||||||
|
en_US: zh-TW female
|
||||||
|
zh_Hans: 粤语女
|
||||||
|
form: form
|
||||||
|
- name: mock_voice_audio
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: clone voice link
|
||||||
|
zh_Hans: 克隆音频链接
|
||||||
|
pt_BR: clone voice link
|
||||||
|
human_description:
|
||||||
|
en_US: clone voice link
|
||||||
|
zh_Hans: 克隆音频链接
|
||||||
|
pt_BR: clone voice link
|
||||||
|
llm_description: clone voice link
|
||||||
|
form: llm
|
||||||
|
- name: mock_voice_text
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: text of clone voice
|
||||||
|
zh_Hans: 克隆音频对应文本
|
||||||
|
pt_BR: text of clone voice
|
||||||
|
human_description:
|
||||||
|
en_US: text of clone voice
|
||||||
|
zh_Hans: 克隆音频对应文本
|
||||||
|
pt_BR: text of clone voice
|
||||||
|
llm_description: text of clone voice
|
||||||
|
form: llm
|
||||||
|
- name: voice_instruct_prompt
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: instruct prompt for voice
|
||||||
|
zh_Hans: 音色指令文本
|
||||||
|
pt_BR: instruct prompt for voice
|
||||||
|
human_description:
|
||||||
|
en_US: instruct prompt for voice
|
||||||
|
zh_Hans: 音色指令文本
|
||||||
|
pt_BR: instruct prompt for voice
|
||||||
|
llm_description: instruct prompt for voice
|
||||||
|
form: llm
|
||||||
|
- name: aws_region
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: region of sagemaker endpoint
|
||||||
|
zh_Hans: SageMaker 端点所在的region
|
||||||
|
pt_BR: region of sagemaker endpoint
|
||||||
|
human_description:
|
||||||
|
en_US: region of sagemaker endpoint
|
||||||
|
zh_Hans: SageMaker 端点所在的region
|
||||||
|
pt_BR: region of sagemaker endpoint
|
||||||
|
llm_description: region of sagemaker endpoint
|
||||||
|
form: form
|
||||||
Loading…
Reference in New Issue