Merge branch 'feat/parent-child-retrieval' of https://github.com/langgenius/dify into feat/parent-child-retrieval
commit
78fff31e61
@ -0,0 +1,96 @@
|
||||
exclude = [
|
||||
"migrations/*",
|
||||
]
|
||||
line-length = 120
|
||||
|
||||
[format]
|
||||
quote-style = "double"
|
||||
|
||||
[lint]
|
||||
preview = true
|
||||
select = [
|
||||
"B", # flake8-bugbear rules
|
||||
"C4", # flake8-comprehensions
|
||||
"E", # pycodestyle E rules
|
||||
"F", # pyflakes rules
|
||||
"FURB", # refurb rules
|
||||
"I", # isort rules
|
||||
"N", # pep8-naming
|
||||
"PT", # flake8-pytest-style rules
|
||||
"PLC0208", # iteration-over-set
|
||||
"PLC2801", # unnecessary-dunder-call
|
||||
"PLC0414", # useless-import-alias
|
||||
"PLE0604", # invalid-all-object
|
||||
"PLE0605", # invalid-all-format
|
||||
"PLR0402", # manual-from-import
|
||||
"PLR1711", # useless-return
|
||||
"PLR1714", # repeated-equality-comparison
|
||||
"RUF013", # implicit-optional
|
||||
"RUF019", # unnecessary-key-check
|
||||
"RUF100", # unused-noqa
|
||||
"RUF101", # redirected-noqa
|
||||
"RUF200", # invalid-pyproject-toml
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
"W191", # tab-indentation
|
||||
"W605", # invalid-escape-sequence
|
||||
]
|
||||
|
||||
ignore = [
|
||||
"E402", # module-import-not-at-top-of-file
|
||||
"E711", # none-comparison
|
||||
"E712", # true-false-comparison
|
||||
"E721", # type-comparison
|
||||
"E722", # bare-except
|
||||
"E731", # lambda-assignment
|
||||
"F821", # undefined-name
|
||||
"F841", # unused-variable
|
||||
"FURB113", # repeated-append
|
||||
"FURB152", # math-constant
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
"N806", # non-lowercase-variable-in-function
|
||||
"N815", # mixed-case-variable-in-class-scope
|
||||
"PT011", # pytest-raises-too-broad
|
||||
"SIM102", # collapsible-if
|
||||
"SIM103", # needless-bool
|
||||
"SIM105", # suppressible-exception
|
||||
"SIM107", # return-in-try-except-finally
|
||||
"SIM108", # if-else-block-instead-of-if-exp
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"SIM300", # yoda-conditions
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"__init__.py" = [
|
||||
"F401", # unused-import
|
||||
"F811", # redefined-while-unused
|
||||
]
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"F401", # unused-import
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
extend-generics = [
|
||||
"_pytest.monkeypatch",
|
||||
"tests.integration_tests",
|
||||
]
|
||||
@ -1,113 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
python_version = sys.version_info
|
||||
if not ((3, 11) <= python_version < (3, 13)):
|
||||
print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
|
||||
raise SystemExit(1)
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if not dify_config.DEBUG:
|
||||
from gevent import monkey
|
||||
|
||||
monkey.patch_all()
|
||||
|
||||
import grpc.experimental.gevent
|
||||
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from flask import Response
|
||||
|
||||
from app_factory import create_app
|
||||
from libs import threadings_utils, version_utils
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers # noqa: F401
|
||||
from extensions.ext_database import db
|
||||
|
||||
# TODO: Find a way to avoid importing models here
|
||||
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
|
||||
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
|
||||
# Ignore ResourceWarning messages for the whole process.
warnings.simplefilter("ignore", ResourceWarning)

# Set the TZ environment variable to UTC for this process.
os.environ["TZ"] = "UTC"
# windows platform not support tzset
if hasattr(time, "tzset"):
    # Apply the TZ change made above (only where time.tzset is available).
    time.tzset()

# preparation before creating app
version_utils.check_supported_python_version()
threadings_utils.apply_gevent_threading_patch()

# create app
app = create_app()
# The celery object is read from the extensions registered on the app by create_app().
celery = app.extensions["celery"]

if dify_config.TESTING:
    print("App is running in TESTING mode")
|
||||
|
||||
|
||||
@app.after_request
def after_request(response):
    """Attach the application version and deploy environment headers to every response."""
    headers = response.headers
    headers.add("X-Version", dify_config.CURRENT_VERSION)
    headers.add("X-Env", dify_config.DEPLOY_ENV)
    return response
|
||||
|
||||
|
||||
@app.route("/health")
def health():
    """Return a JSON body holding the process id, an "ok" status and the current version."""
    body = json.dumps(
        {
            "pid": os.getpid(),
            "status": "ok",
            "version": dify_config.CURRENT_VERSION,
        }
    )
    return Response(body, status=200, content_type="application/json")
|
||||
|
||||
|
||||
@app.route("/threads")
def threads():
    """Return the process id, the active thread count and name/id/liveness of each thread."""
    active_total = threading.active_count()

    # One entry per thread currently known to the threading module
    thread_list = [
        {
            "name": worker.name,
            "id": worker.ident,
            "is_alive": worker.is_alive(),
        }
        for worker in threading.enumerate()
    ]

    return {
        "pid": os.getpid(),
        "thread_num": active_total,
        "threads": thread_list,
    }
|
||||
|
||||
|
||||
@app.route("/db-pool-stat")
def pool_stat():
    """Return the process id together with the counters of the database engine's connection pool."""
    pool = db.engine.pool
    stats = {"pid": os.getpid()}
    stats["pool_size"] = pool.size()
    stats["checked_in_connections"] = pool.checkedin()
    stats["checked_out_connections"] = pool.checkedout()
    stats["overflow_connections"] = pool.overflow()
    stats["connection_timeout"] = pool.timeout()
    # _recycle is a private pool attribute; it is read directly, as before
    stats["recycle_time"] = pool._recycle
    return stats
|
||||
|
||||
|
||||
# When executed directly as a script, start Flask's built-in server on all interfaces, port 5001.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5001)
|
||||
|
||||
@ -0,0 +1,145 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
    """
    Model class for OpenAI-compatible text2speech model.
    """

    def _invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials; "endpoint_url" is required, "api_key" is optional
        :param content_text: text content to be converted to speech
        :param voice: model voice/speaker
        :param user: unique user id
        :return: audio data as bytes iterator
        :raises InvokeBadRequestError: if endpoint_url is missing or the endpoint answers with a non-200 status
        """
        # Set up headers with authentication if provided
        headers = {}
        if api_key := credentials.get("api_key"):
            headers["Authorization"] = f"Bearer {api_key}"

        # Construct endpoint URL; fail with a clear message instead of an AttributeError on None
        endpoint_url = credentials.get("endpoint_url")
        if not endpoint_url:
            raise InvokeBadRequestError("endpoint_url is required in credentials")
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"
        endpoint_url = urljoin(endpoint_url, "audio/speech")

        # Get audio format from model properties
        audio_format = self._get_model_audio_type(model, credentials)

        # Split text into chunks if needed based on word limit
        word_limit = self._get_model_word_limit(model, credentials)
        sentences = self._split_text_into_sentences(content_text, word_limit)

        for sentence in sentences:
            # Prepare request payload
            payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}

            # Make POST request. The (connect, read) timeout keeps a stalled endpoint from blocking
            # forever, and the context manager releases the connection even when the consumer stops
            # iterating early or an error is raised.
            with requests.post(
                endpoint_url,
                headers=headers,
                json=payload,
                stream=True,
                timeout=(10, 300),
            ) as response:
                if response.status_code != 200:
                    raise InvokeBadRequestError(response.text)

                # Stream the audio data
                for chunk in response.iter_content(chunk_size=4096):
                    if chunk:
                        yield chunk

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by synthesizing a short test text.

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the test invocation fails for any reason
        """
        audio_stream = None
        try:
            # Get default voice for validation
            voice = self._get_model_default_voice(model, credentials)

            # Test with a simple text; only the first audio chunk is needed
            audio_stream = self._invoke(
                model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
            )
            next(audio_stream)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex
        finally:
            # Close the generator so the streamed HTTP response is released right away
            if audio_stream is not None:
                audio_stream.close()

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get customizable model schema

        :param model: model name
        :param credentials: model credentials; reads "voices" (comma-separated), "audio_type" and "word_limit"
        :return: model entity describing the TTS model
        """
        # Parse voices from comma-separated string (a None value is treated like an empty string)
        raw_voices = credentials.get("voices", "alloy") or ""
        voices = []

        for voice in raw_voices.strip().split(","):
            voice = voice.strip()
            if not voice:
                continue

            # Use en-US for all voices
            voices.append(
                {
                    "name": voice,
                    "mode": voice,
                    "language": "en-US",
                }
            )

        # If no voices provided or all voices were empty strings, use 'alloy' as default
        if not voices:
            voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]

        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                # Fall back to the defaults when the credential is absent, None or empty
                ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type") or "mp3",
                ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit") or 4096),
                ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
                ModelPropertyKey.VOICES: voices,
            },
        )

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Override base get_tts_model_voices to handle customizable voices

        :param model: model name
        :param credentials: model credentials
        :param language: accepted for interface compatibility; voices are not filtered by it
        :return: list of {"name": ..., "value": ...} dicts, one per configured voice
        :raises ValueError: if the model schema declares no voices
        """
        model_schema = self.get_customizable_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
            raise ValueError("this model does not support voice")

        voices = model_schema.model_properties[ModelPropertyKey.VOICES]

        # Always return all voices regardless of language
        return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
||||
@ -1,4 +1,4 @@
|
||||
from .common import ChatRole
|
||||
from .maas import MaasError, MaasService
|
||||
|
||||
__all__ = ["MaasService", "ChatRole", "MaasError"]
|
||||
__all__ = ["ChatRole", "MaasError", "MaasService"]
|
||||
|
||||
@ -0,0 +1,32 @@
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
|
||||
def get_thread_messages_length(conversation_id: str) -> int:
    """
    Get the number of thread messages based on the parent message id.

    Loads id, parent_message_id and answer of every message in the conversation
    (newest first), passes them through extract_thread_messages and counts the
    result, leaving out a leading message whose answer is still empty.
    """
    # Fetch all messages related to the conversation, newest first
    rows = (
        db.session.query(Message.id, Message.parent_message_id, Message.answer)
        .filter(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
        .all()
    )

    # Extract thread messages
    thread = extract_thread_messages(rows)

    # Exclude the newly created message with an empty answer
    if thread and not thread[0].answer:
        thread.pop(0)

    return len(thread)
|
||||
@ -0,0 +1,46 @@
|
||||
identity:
|
||||
name: chinese_toxicity_detector
|
||||
author: AWS
|
||||
label:
|
||||
en_US: Chinese Toxicity Detector
|
||||
zh_Hans: 中文有害内容检测
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool to detect Chinese toxicity
|
||||
zh_Hans: 检测中文有害内容的工具
|
||||
llm: A tool that checks if Chinese content is safe for work
|
||||
parameters:
|
||||
- name: sagemaker_endpoint
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: sagemaker endpoint for moderation
|
||||
zh_Hans: 内容审核的SageMaker端点
|
||||
human_description:
|
||||
en_US: sagemaker endpoint for content moderation
|
||||
zh_Hans: 内容审核的SageMaker端点
|
||||
llm_description: sagemaker endpoint for content moderation
|
||||
form: form
|
||||
- name: content_text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: content text
|
||||
zh_Hans: 待审核文本
|
||||
human_description:
|
||||
en_US: text content to be moderated
|
||||
zh_Hans: 需要审核的文本内容
|
||||
llm_description: text content to be moderated
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
human_description:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
llm_description: region of sagemaker endpoint
|
||||
form: form
|
||||
@ -0,0 +1,418 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from botocore.exceptions import ClientError
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Language codes accepted by this tool. TranscribeTool._invoke rejects any
# `language_code` or `language_options` entry that is not in this list.
LanguageCodeOptions = [
    "af-ZA",
    "ar-AE",
    "ar-SA",
    "da-DK",
    "de-CH",
    "de-DE",
    "en-AB",
    "en-AU",
    "en-GB",
    "en-IE",
    "en-IN",
    "en-US",
    "en-WL",
    "es-ES",
    "es-US",
    "fa-IR",
    "fr-CA",
    "fr-FR",
    "he-IL",
    "hi-IN",
    "id-ID",
    "it-IT",
    "ja-JP",
    "ko-KR",
    "ms-MY",
    "nl-NL",
    "pt-BR",
    "pt-PT",
    "ru-RU",
    "ta-IN",
    "te-IN",
    "tr-TR",
    "zh-CN",
    "zh-TW",
    "th-TH",
    "en-ZA",
    "en-NZ",
    "vi-VN",
    "sv-SE",
    "ab-GE",
    "ast-ES",
    "az-AZ",
    "ba-RU",
    "be-BY",
    "bg-BG",
    "bn-IN",
    "bs-BA",
    "ca-ES",
    "ckb-IQ",
    "ckb-IR",
    "cs-CZ",
    "cy-WL",
    "el-GR",
    "et-ET",
    "eu-ES",
    "fi-FI",
    "gl-ES",
    "gu-IN",
    "ha-NG",
    "hr-HR",
    "hu-HU",
    "hy-AM",
    "is-IS",
    "ka-GE",
    "kab-DZ",
    "kk-KZ",
    "kn-IN",
    "ky-KG",
    "lg-IN",
    "lt-LT",
    "lv-LV",
    "mhr-RU",
    "mi-NZ",
    "mk-MK",
    "ml-IN",
    "mn-MN",
    "mr-IN",
    "mt-MT",
    "no-NO",
    "or-IN",
    "pa-IN",
    "pl-PL",
    "ps-AF",
    "ro-RO",
    "rw-RW",
    "si-LK",
    "sk-SK",
    "sl-SI",
    "so-SO",
    "sr-RS",
    "su-ID",
    "sw-BI",
    "sw-KE",
    "sw-RW",
    "sw-TZ",
    "sw-UG",
    "tl-PH",
    "tt-RU",
    "ug-CN",
    "uk-UA",
    "uz-UZ",
    "wo-SN",
    "zu-ZA",
]

# Media container/codec names. NOTE(review): not referenced by the code visible
# in this file — presumably intended for validating `file_type`; confirm.
MediaFormat = ["mp3", "mp4", "wav", "flac", "ogg", "amr", "webm", "m4a"]
|
||||
|
||||
|
||||
# Compiled once at import time instead of on every is_url() call.
_URL_PATTERN = re.compile(
    r"^"  # Start of the string
    r"(?:http|https)://"  # Protocol (http or https)
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # Domain
    r"localhost|"  # localhost
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # IP address
    r"(?::\d+)?"  # Optional port
    r"(?:/?|[/?]\S+)"  # Path
    r"$",  # End of the string
    re.IGNORECASE,
)


def is_url(text):
    """
    Tell whether *text* looks like an http(s) URL.

    Surrounding whitespace is ignored. The host may be a dotted domain name,
    "localhost" or a dotted quad of 1-3 digit groups (octet ranges are not
    checked); a port and a whitespace-free path/query are optional.

    :param text: candidate value; empty or non-string values are never URLs
    :return: True if the stripped text matches the URL pattern, else False
    """
    # Non-string input used to raise AttributeError on .strip(); treat it as "not a URL"
    if not text or not isinstance(text, str):
        return False
    return bool(_URL_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
def upload_file_from_url_to_s3(s3_client, url, bucket_name, s3_key=None, max_retries=3):
    """
    Upload a file from a URL to an S3 bucket with retries and better error handling.

    Parameters:
    - s3_client: client object exposing upload_fileobj(fileobj, bucket, key, ExtraArgs=...)
    - url (str): The URL of the file to upload
    - bucket_name (str): The name of the S3 bucket
    - s3_key (str): The desired key (path) in S3. If None, will use the filename from URL
    - max_retries (int): Maximum number of download attempts

    Returns:
    - tuple: (s3_uri, message). On success s3_uri is "s3://<bucket>/<key>". On failure it is
      falsy (None, or False when url/bucket_name is missing) and message describes the error.
    """

    # Validate inputs
    if not url or not bucket_name:
        return False, "URL and bucket name are required"

    retry_count = 0
    while retry_count < max_retries:
        try:
            # Download the file from URL; the context manager releases the streamed
            # connection on every exit path
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()

                # If s3_key is not provided, try to get filename from URL
                if not s3_key:
                    parsed_url = urlparse(url)
                    filename = os.path.basename(parsed_url.path.split("/file-preview")[0])
                    # A URL without a usable filename would otherwise yield the bare prefix as key
                    s3_key = "transcribe-files/" + (filename or uuid.uuid4().hex)

                extra_args = {"ACL": "private"}  # Ensure the uploaded file is private
                # Only forward the content type when the server sent one; None is not a valid value
                content_type = response.headers.get("content-type")
                if content_type:
                    extra_args["ContentType"] = content_type

                # Have the raw stream undo any Content-Encoding (gzip/deflate) so the stored
                # object holds the actual media bytes
                response.raw.decode_content = True

                # Upload the file to S3
                s3_client.upload_fileobj(response.raw, bucket_name, s3_key, ExtraArgs=extra_args)

            return f"s3://{bucket_name}/{s3_key}", f"Successfully uploaded file to s3://{bucket_name}/{s3_key}"

        except RequestException as e:
            # Only download problems are retried
            retry_count += 1
            if retry_count == max_retries:
                return None, f"Failed to download file from URL after {max_retries} attempts: {str(e)}"
            continue

        except ClientError as e:
            return None, f"AWS S3 error: {str(e)}"

        except Exception as e:
            return None, f"Unexpected error: {str(e)}"

    return None, "Maximum retries exceeded"
|
||||
|
||||
|
||||
class TranscribeTool(BuiltinTool):
    """
    Tool that uploads a media file to S3 and transcribes it with a boto3 "transcribe" client.

    Note that you must include one of LanguageCode, IdentifyLanguage,
    or IdentifyMultipleLanguages in your request.
    If you include more than one of these parameters, your transcription job fails.
    """

    # boto3 clients; created by _invoke on first use and reused afterwards
    s3_client: Any = None
    transcribe_client: Any = None

    def _transcribe_audio(self, audio_file_uri, file_type, **extra_args):
        """
        Start a transcription job and block until it has completed or failed.

        :param audio_file_uri: media location, passed as Media.MediaFileUri
        :param file_type: accepted for interface compatibility; not forwarded to the API
        :param extra_args: extra keyword arguments for start_transcription_job
        :return: (transcript_file_uri, None) on success, (None, error message) otherwise
        """
        job_name = f"{int(time.time())}-{uuid.uuid4()}"
        try:
            # Start transcription job
            self.transcribe_client.start_transcription_job(
                TranscriptionJobName=job_name, Media={"MediaFileUri": audio_file_uri}, **extra_args
            )

            # Wait for the job to complete, polling every 5 seconds
            while True:
                job = self.transcribe_client.get_transcription_job(TranscriptionJobName=job_name)["TranscriptionJob"]
                job_status = job["TranscriptionJobStatus"]
                if job_status in ("COMPLETED", "FAILED"):
                    break
                time.sleep(5)

            if job_status == "COMPLETED":
                return job["Transcript"]["TranscriptFileUri"], None

            # Surface the service's failure reason when it is provided
            failure_reason = job.get("FailureReason")
            if failure_reason:
                return None, f"Error: TranscriptionJobStatus:{job_status} FailureReason:{failure_reason}"
            return None, f"Error: TranscriptionJobStatus:{job_status} "

        except Exception as e:
            return None, f"Error: {str(e)}"

    def _download_and_read_transcript(
        self, transcript_file_uri: str, max_retries: int = 3
    ) -> tuple[Union[str, None], Union[str, None]]:
        """
        Download and read the transcript file from the given URI.

        Parameters:
        - transcript_file_uri (str): The URI of the transcript file
        - max_retries (int): Maximum number of download attempts

        Returns:
        - tuple: (text, error) - (Transcribed text if successful, error message if failed)
        """
        retry_count = 0
        while retry_count < max_retries:
            try:
                # Download the transcript file
                response = requests.get(transcript_file_uri, timeout=30)
                response.raise_for_status()

                # Parse the JSON content
                transcript_data = response.json()

                # Check if speaker labels are present and enabled
                has_speaker_labels = (
                    "results" in transcript_data
                    and "speaker_labels" in transcript_data["results"]
                    and "segments" in transcript_data["results"]["speaker_labels"]
                )

                if has_speaker_labels:
                    # Get speaker segments
                    segments = transcript_data["results"]["speaker_labels"]["segments"]
                    items = transcript_data["results"]["items"]

                    # Create a mapping of start_time -> speaker_label
                    time_to_speaker = {}
                    for segment in segments:
                        speaker_label = segment["speaker_label"]
                        for item in segment["items"]:
                            time_to_speaker[item["start_time"]] = speaker_label

                    # Build transcript with speaker labels
                    current_speaker = None
                    transcript_parts = []

                    for item in items:
                        # Punctuation items are appended as-is and never change the current speaker
                        if item["type"] == "punctuation":
                            transcript_parts.append(item["alternatives"][0]["content"])
                            continue

                        start_time = item["start_time"]
                        speaker = time_to_speaker.get(start_time)

                        # Emit a "[speaker]: " marker on a new line whenever the speaker changes
                        if speaker != current_speaker:
                            current_speaker = speaker
                            transcript_parts.append(f"\n[{speaker}]: ")

                        transcript_parts.append(item["alternatives"][0]["content"])

                    return " ".join(transcript_parts).strip(), None
                else:
                    # Extract the transcription text
                    # The transcript text is typically in the 'results' -> 'transcripts' array
                    if "results" in transcript_data and "transcripts" in transcript_data["results"]:
                        transcripts = transcript_data["results"]["transcripts"]
                        if transcripts:
                            # Combine all transcript segments
                            full_text = " ".join(t.get("transcript", "") for t in transcripts)
                            return full_text, None

                return None, "No transcripts found in the response"

            except requests.exceptions.RequestException as e:
                # Only download problems are retried
                retry_count += 1
                if retry_count == max_retries:
                    return None, f"Failed to download transcript file after {max_retries} attempts: {str(e)}"
                continue

            except json.JSONDecodeError as e:
                return None, f"Failed to parse transcript JSON: {str(e)}"

            except Exception as e:
                return None, f"Unexpected error while processing transcript: {str(e)}"

        return None, "Maximum retries exceeded"

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools

        Validates the parameters, uploads the file behind `file_url` to the given S3 bucket,
        runs a transcription job on it and returns the transcript (or an error description)
        as a text message.

        :param user_id: id of the invoking user (not used by this tool)
        :param tool_parameters: reads file_url, file_type, language_code, identify_language,
            identify_multiple_languages, language_options ("|"-separated), s3_bucket_name,
            ShowSpeakerLabels, MaxSpeakerLabels and aws_region
        :return: text message with the transcript or with an error description
        """
        try:
            if not self.transcribe_client:
                # region_name=None lets boto3 fall back to its default region resolution
                aws_region = tool_parameters.get("aws_region") or None
                self.transcribe_client = boto3.client("transcribe", region_name=aws_region)
                self.s3_client = boto3.client("s3", region_name=aws_region)

            file_url = tool_parameters.get("file_url")
            file_type = tool_parameters.get("file_type")
            language_code = tool_parameters.get("language_code")
            identify_language = tool_parameters.get("identify_language", True)
            identify_multiple_languages = tool_parameters.get("identify_multiple_languages", False)
            language_options_str = tool_parameters.get("language_options")
            s3_bucket_name = tool_parameters.get("s3_bucket_name")
            show_speaker_labels = tool_parameters.get("ShowSpeakerLabels", True)
            max_speaker_labels = tool_parameters.get("MaxSpeakerLabels", 2)

            # Check the input params
            if not s3_bucket_name:
                return self.create_text_message(text="s3_bucket_name is required")

            language_options = None
            if language_options_str:
                language_options = language_options_str.split("|")
                for lang in language_options:
                    if lang not in LanguageCodeOptions:
                        return self.create_text_message(
                            text=f"{lang} is not supported, should be one of {LanguageCodeOptions}"
                        )

            if language_code and language_code not in LanguageCodeOptions:
                err_msg = f"language_code:{language_code} is not supported, should be one of {LanguageCodeOptions}"
                return self.create_text_message(text=err_msg)

            # Built with implicit concatenation so no source indentation leaks into the message
            err_msg = (
                f"identify_language:{identify_language}, "
                f"identify_multiple_languages:{identify_multiple_languages}, "
                "Note that you must include one of LanguageCode, IdentifyLanguage, "
                "or IdentifyMultipleLanguages in your request. "
                "If you include more than one of these parameters, "
                "your transcription job fails."
            )
            # Exactly one language mode must be selected; this also rejects "none selected",
            # which previously was only detected by the service after the upload
            selected_modes = sum(
                bool(mode) for mode in (language_code, identify_language, identify_multiple_languages)
            )
            if selected_modes != 1:
                return self.create_text_message(text=err_msg)

            # Send only the selected mode instead of also sending the other flags as False
            extra_args: dict[str, Any] = {}
            if language_code:
                extra_args["LanguageCode"] = language_code
            elif identify_language:
                extra_args["IdentifyLanguage"] = True
            else:
                extra_args["IdentifyMultipleLanguages"] = True
            if language_options:
                extra_args["LanguageOptions"] = language_options
            if show_speaker_labels:
                # "number" parameters may arrive as float/str; the API field is an integer
                extra_args["Settings"] = {"ShowSpeakerLabels": True, "MaxSpeakerLabels": int(max_speaker_labels)}

            # upload to s3 bucket
            s3_path_result, error = upload_file_from_url_to_s3(self.s3_client, url=file_url, bucket_name=s3_bucket_name)
            if not s3_path_result:
                return self.create_text_message(text=error)

            transcript_file_uri, error = self._transcribe_audio(
                audio_file_uri=s3_path_result,
                file_type=file_type,
                **extra_args,
            )
            if not transcript_file_uri:
                return self.create_text_message(text=error)

            # Download and read the transcript
            transcript_text, error = self._download_and_read_transcript(transcript_file_uri)
            if not transcript_text:
                # error is None when the transcript was read successfully but is empty
                return self.create_text_message(text=error or "Transcript is empty")

            return self.create_text_message(text=transcript_text)

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}")
|
||||
@ -0,0 +1,133 @@
|
||||
identity:
|
||||
name: transcribe_asr
|
||||
author: AWS
|
||||
label:
|
||||
en_US: TranscribeASR
|
||||
zh_Hans: Transcribe语音识别转录
|
||||
pt_BR: TranscribeASR
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for ASR (Automatic Speech Recognition) - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: AWS 语音识别转录服务, 请参考 https://aws.amazon.com/cn/pm/transcribe/#Learn_More_About_Amazon_Transcribe
|
||||
pt_BR: A tool for ASR (Automatic Speech Recognition).
|
||||
llm: A tool for ASR (Automatic Speech Recognition).
|
||||
parameters:
|
||||
- name: file_url
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: video or audio file url for transcribe
|
||||
zh_Hans: 语音或者视频文件url
|
||||
pt_BR: video or audio file url for transcribe
|
||||
human_description:
|
||||
en_US: video or audio file url for transcribe
|
||||
zh_Hans: 语音或者视频文件url
|
||||
pt_BR: video or audio file url for transcribe
|
||||
llm_description: video or audio file url for transcribe
|
||||
form: llm
|
||||
- name: language_code
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Language Code
|
||||
zh_Hans: 语言编码
|
||||
pt_BR: Language Code
|
||||
human_description:
|
||||
en_US: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html
|
||||
zh_Hans: 语言编码,例如zh-CN, en-US 可参考 https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html
|
||||
pt_BR: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html
|
||||
llm_description: The language code used to create your transcription job.
|
||||
form: llm
|
||||
- name: identify_language
|
||||
type: boolean
|
||||
default: true
|
||||
required: false
|
||||
label:
|
||||
en_US: Automatically Identify Language
|
||||
zh_Hans: 自动识别语言
|
||||
pt_BR: Automatically Identify Language
|
||||
human_description:
|
||||
en_US: Automatically Identify Language
|
||||
zh_Hans: 自动识别语言
|
||||
pt_BR: Automatically Identify Language
|
||||
llm_description: Enable Automatically Identify Language
|
||||
form: form
|
||||
- name: identify_multiple_languages
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Automatically Identify Multiple Languages
|
||||
zh_Hans: 自动识别多种语言
|
||||
pt_BR: Automatically Identify Multiple Languages
|
||||
human_description:
|
||||
en_US: Automatically Identify Multiple Languages
|
||||
zh_Hans: 自动识别多种语言
|
||||
pt_BR: Automatically Identify Multiple Languages
|
||||
llm_description: Enable Automatically Identify Multiple Languages
|
||||
form: form
|
||||
- name: language_options
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Language Options
|
||||
zh_Hans: 语言种类选项
|
||||
pt_BR: Language Options
|
||||
human_description:
|
||||
en_US: Separated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media
|
||||
zh_Hans: 您可以指定两个或更多的语言代码来表示您认为可能出现在媒体中的语言。用|分隔,如 zh-CN|en-US
|
||||
pt_BR: Separated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media
|
||||
llm_description: Separated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media
|
||||
form: llm
|
||||
- name: s3_bucket_name
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: s3 bucket name
|
||||
zh_Hans: s3 存储桶名称
|
||||
pt_BR: s3 bucket name
|
||||
human_description:
|
||||
en_US: s3 bucket name to store transcribe files (don't add prefix s3://)
|
||||
zh_Hans: s3 存储桶名称,用于存储转录文件 (不需要前缀 s3://)
|
||||
pt_BR: s3 bucket name to store transcribe files (don't add prefix s3://)
|
||||
llm_description: s3 bucket name to store transcribe files
|
||||
form: form
|
||||
- name: ShowSpeakerLabels
|
||||
type: boolean
|
||||
required: true
|
||||
default: true
|
||||
label:
|
||||
en_US: ShowSpeakerLabels
|
||||
zh_Hans: 显示说话人标签
|
||||
pt_BR: ShowSpeakerLabels
|
||||
human_description:
|
||||
en_US: Enables speaker partitioning (diarization) in your transcription output
|
||||
zh_Hans: 在转录输出中启用说话人分区(说话人分离)
|
||||
pt_BR: Enables speaker partitioning (diarization) in your transcription output
|
||||
llm_description: Enables speaker partitioning (diarization) in your transcription output
|
||||
form: form
|
||||
- name: MaxSpeakerLabels
|
||||
type: number
|
||||
required: true
|
||||
default: 2
|
||||
label:
|
||||
en_US: MaxSpeakerLabels
|
||||
zh_Hans: 说话人标签数量
|
||||
pt_BR: MaxSpeakerLabels
|
||||
human_description:
|
||||
en_US: Specify the maximum number of speakers you want to partition in your media
|
||||
zh_Hans: 指定您希望在媒体中划分的最多演讲者数量。
|
||||
pt_BR: Specify the maximum number of speakers you want to partition in your media
|
||||
llm_description: Specify the maximum number of speakers you want to partition in your media
|
||||
form: form
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: Please enter the AWS region for the transcribe service, for example 'us-east-1'.
|
||||
zh_Hans: 请输入Transcribe的 AWS 区域,例如 'us-east-1'。
|
||||
llm_description: Please enter the AWS region for the transcribe service, for example 'us-east-1'.
|
||||
form: form
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue