merge main

pull/12372/head
Joel 2 years ago
commit 99ffe43e91

@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD= REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1 REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# redis Cluster configuration.
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# PostgreSQL database configuration # PostgreSQL database configuration
DB_USERNAME=postgres DB_USERNAME=postgres
DB_PASSWORD=difyai123456 DB_PASSWORD=difyai123456
@ -234,6 +239,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# OpenSearch configuration # OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1 OPENSEARCH_HOST=127.0.0.1

@ -589,7 +589,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green")) click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e: except Exception as e:
logging.exception(f"Database migration failed: {e}") logging.exception("Failed to execute database migration")
finally: finally:
lock.release() lock.release()
else: else:
@ -633,7 +633,7 @@ where sites.id is null limit 1000"""
except Exception as e: except Exception as e:
failed_app_ids.append(app_id) failed_app_ids.append(app_id)
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red")) click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}") logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
continue continue
if not processed_count: if not processed_count:

@ -616,6 +616,11 @@ class DataSetConfig(BaseSettings):
default=False, default=False,
) )
PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field(
description="Interval in days for message cleanup operations - plan: sandbox",
default=30,
)
class WorkspaceConfig(BaseSettings): class WorkspaceConfig(BaseSettings):
""" """

@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
description="Socket timeout in seconds for Redis Sentinel connections", description="Socket timeout in seconds for Redis Sentinel connections",
default=0.1, default=0.1,
) )
REDIS_USE_CLUSTERS: bool = Field(
description="Enable Redis Clusters mode for high availability",
default=False,
)
REDIS_CLUSTERS: Optional[str] = Field(
description="Comma-separated list of Redis Clusters nodes (host:port)",
default=None,
)
REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
description="Password for Redis Clusters authentication (if required)",
default=None,
)

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, PositiveInt
class AnalyticdbConfig(BaseModel): class AnalyticdbConfig(BaseModel):
@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
description="The password for accessing the specified namespace within the AnalyticDB instance" description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).", " (if namespace feature is enabled).",
) )
ANALYTICDB_HOST: Optional[str] = Field(
default=None, description="The host of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_PORT: PositiveInt = Field(
default=5432, description="The port of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="0.11.1", default="0.11.2",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -9,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
enterprise_license_required,
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
@ -28,6 +29,7 @@ class AppListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required
def get(self): def get(self):
"""Get app list""" """Get app list"""
@ -149,6 +151,7 @@ class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
def get(self, app_model): def get(self, app_model):

@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource):
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception(f"internal server error, {str(e)}.") logging.exception("Failed to handle post request to ChatMessageAudioApi")
raise InternalServerError() raise InternalServerError()
@ -128,7 +128,7 @@ class ChatMessageTextApi(Resource):
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception(f"internal server error, {str(e)}.") logging.exception("Failed to handle post request to ChatMessageTextApi")
raise InternalServerError() raise InternalServerError()
@ -170,7 +170,7 @@ class TextModesApi(Resource):
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception(f"internal server error, {str(e)}.") logging.exception("Failed to handle get request to TextModesApi")
raise InternalServerError() raise InternalServerError()

@ -12,7 +12,7 @@ from controllers.console.auth.error import (
InvalidTokenError, InvalidTokenError,
PasswordMismatchError, PasswordMismatchError,
) )
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
@ -48,7 +48,7 @@ class ForgotPasswordSendEmailApi(Resource):
token = AccountService.send_reset_password_email(email=args["email"], language=language) token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"} return {"result": "fail", "data": token, "code": "account_not_found"}
else: else:
raise NotAllowedRegister() raise AccountNotFound()
else: else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)

@ -16,9 +16,9 @@ from controllers.console.auth.error import (
) )
from controllers.console.error import ( from controllers.console.error import (
AccountBannedError, AccountBannedError,
AccountNotFound,
EmailSendIpLimitError, EmailSendIpLimitError,
NotAllowedCreateWorkspace, NotAllowedCreateWorkspace,
NotAllowedRegister,
) )
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
@ -76,7 +76,7 @@ class LoginApi(Resource):
token = AccountService.send_reset_password_email(email=args["email"], language=language) token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"} return {"result": "fail", "data": token, "code": "account_not_found"}
else: else:
raise NotAllowedRegister() raise AccountNotFound()
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0: if len(tenants) == 0:
@ -119,7 +119,7 @@ class ResetPasswordSendEmailApi(Resource):
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language) token = AccountService.send_reset_password_email(email=args["email"], language=language)
else: else:
raise NotAllowedRegister() raise AccountNotFound()
else: else:
token = AccountService.send_reset_password_email(account=account, language=language) token = AccountService.send_reset_password_email(account=account, language=language)
@ -148,7 +148,7 @@ class EmailCodeLoginSendEmailApi(Resource):
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language) token = AccountService.send_email_code_login_email(email=args["email"], language=language)
else: else:
raise NotAllowedRegister() raise AccountNotFound()
else: else:
token = AccountService.send_email_code_login_email(account=account, language=language) token = AccountService.send_email_code_login_email(account=account, language=language)

@ -10,7 +10,7 @@ from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -44,6 +44,7 @@ class DatasetListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required
def get(self): def get(self):
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)

@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
retry_documents.append(document) retry_documents.append(document)
except Exception as e: except Exception as e:
logging.exception(f"Document {document_id} retry failed: {str(e)}") logging.exception(f"Failed to retry document, document id: {document_id}")
continue continue
# retry document # retry document
DocumentService.retry_document(dataset_id, retry_documents) DocumentService.retry_document(dataset_id, retry_documents)

@ -52,8 +52,8 @@ class AccountBannedError(BaseHTTPException):
code = 400 code = 400
class NotAllowedRegister(BaseHTTPException): class AccountNotFound(BaseHTTPException):
error_code = "unauthorized" error_code = "account_not_found"
description = "Account not found." description = "Account not found."
code = 400 code = 400
@ -86,3 +86,9 @@ class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded" error_code = "no_file_uploaded"
description = "Please upload your file." description = "Please upload your file."
code = 400 code = 400
class UnauthorizedAndForceLogout(BaseHTTPException):
error_code = "unauthorized_and_force_logout"
description = "Unauthorized and force logout."
code = 401

@ -45,7 +45,7 @@ class RemoteFileUploadApi(Resource):
resp = ssrf_proxy.head(url=url) resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK: if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3) resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp.raise_for_status() resp.raise_for_status()
file_info = helpers.guess_file_info_from_response(resp) file_info = helpers.guess_file_info_from_response(resp)

@ -14,7 +14,7 @@ from controllers.console.workspace.error import (
InvalidInvitationCodeError, InvalidInvitationCodeError,
RepeatPasswordNotMatchError, RepeatPasswordNotMatchError,
) )
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone from libs.helper import TimestampField, timezone
@ -79,6 +79,7 @@ class AccountProfileApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
@enterprise_license_required
def get(self): def get(self):
return current_user return current_user

@ -1,3 +1,5 @@
from urllib import parse
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse from flask_restful import Resource, abort, marshal_with, reqparse
@ -57,11 +59,12 @@ class MemberInviteEmailApi(Resource):
token = RegisterService.invite_new_member( token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
) )
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append( invitation_results.append(
{ {
"status": "success", "status": "success",
"email": invitee_email, "email": invitee_email,
"url": f"{console_web_url}/activate?email={invitee_email}&token={token}", "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
} }
) )
except AccountAlreadyInTenantError: except AccountAlreadyInTenantError:

@ -72,7 +72,10 @@ class DefaultModelApi(Resource):
model=model_setting["model"], model=model_setting["model"],
) )
except Exception as ex: except Exception as ex:
logging.exception(f"{model_setting['model_type']} save error: {ex}") logging.exception(
f"Failed to update default model, model type: {model_setting['model_type']},"
f" model:{model_setting.get('model')}"
)
raise ex raise ex
return {"result": "success"} return {"result": "success"}
@ -156,7 +159,10 @@ class ModelProviderModelApi(Resource):
credentials=args["credentials"], credentials=args["credentials"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
logging.exception(f"save model credentials error: {ex}") logging.exception(
f"Failed to save model credentials, tenant_id: {tenant_id},"
f" model: {args.get('model')}, model_type: {args.get('model_type')}"
)
raise ValueError(str(ex)) raise ValueError(str(ex))
return {"result": "success"}, 200 return {"result": "success"}, 200

@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import alphanumeric, uuid_value from libs.helper import alphanumeric, uuid_value
from libs.login import login_required from libs.login import login_required
@ -549,6 +549,7 @@ class ToolLabelsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required
def get(self): def get(self):
return jsonable_encoder(ToolLabelsService.list_tool_labels()) return jsonable_encoder(ToolLabelsService.list_tool_labels())

@ -8,10 +8,10 @@ from flask_login import current_user
from configs import dify_config from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.workspace.error import AccountNotInitializedError
from models.model import DifySetup from models.model import DifySetup
from services.feature_service import FeatureService from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
def account_initialization_required(view): def account_initialization_required(view):
@ -142,3 +142,15 @@ def setup_required(view):
return view(*args, **kwargs) return view(*args, **kwargs)
return decorated return decorated
def enterprise_license_required(view):
@wraps(view)
def decorated(*args, **kwargs):
settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
return view(*args, **kwargs)
return decorated

@ -59,7 +59,7 @@ class AudioApi(WebApiResource):
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception(f"internal server error: {str(e)}") logging.exception("Failed to handle post request to AudioApi")
raise InternalServerError() raise InternalServerError()
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception(f"internal server error: {str(e)}") logging.exception("Failed to handle post request to TextApi")
raise InternalServerError() raise InternalServerError()

@ -16,9 +16,7 @@ class FileUploadConfigManager:
file_upload_dict = config.get("file_upload") file_upload_dict = config.get("file_upload")
if file_upload_dict: if file_upload_dict:
if file_upload_dict.get("enabled"): if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get( transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
"allowed_upload_methods", []
)
data = { data = {
"image_config": { "image_config": {
"number_limits": file_upload_dict["number_limits"], "number_limits": file_upload_dict["number_limits"],

@ -362,5 +362,5 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}")
raise e raise e

@ -242,7 +242,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_listener_time = time.time() start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(f"Failed to listen audio message, task_id: {task_id}")
break break
if tts_publisher: if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

@ -33,8 +33,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id, tenant_id=app_config.tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
), ),
) )
for k, v in user_inputs.items() for k, v in user_inputs.items()
@ -47,8 +47,8 @@ class BaseAppGenerator:
tenant_id=app_config.tenant_id, tenant_id=app_config.tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
), ),
) )
for k, v in user_inputs.items() for k, v in user_inputs.items()
@ -91,6 +91,9 @@ class BaseAppGenerator:
) )
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str): if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number # may raise ValueError if user_input_value is not a valid number
try: try:
if "." in value: if "." in value:

@ -80,7 +80,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(f"Failed to handle response, conversation_id: {conversation.id}")
raise e raise e
def _get_conversation_by_user( def _get_conversation_by_user(

@ -298,5 +298,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
)
raise e raise e

@ -216,7 +216,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else: else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
break break
if tts_publisher: if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

@ -86,7 +86,7 @@ class MessageCycleManage:
conversation.name = name conversation.name = name
except Exception as e: except Exception as e:
if dify_config.DEBUG: if dify_config.DEBUG:
logging.exception(f"generate conversation name failed: {e}") logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}")
pass pass
db.session.merge(conversation) db.session.merge(conversation)

@ -28,8 +28,8 @@ class FileUploadConfig(BaseModel):
image_config: Optional[ImageConfig] = None image_config: Optional[ImageConfig] = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_extensions: Sequence[str] = Field(default_factory=list) allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0 number_limits: int = 0

@ -41,7 +41,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
if moderation_result is True: if moderation_result is True:
return True return True
except Exception as ex: except Exception as ex:
logger.exception(ex) logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
raise InvokeBadRequestError("Rate limit exceeded, please try again later.") raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
return False return False

@ -29,7 +29,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
spec.loader.exec_module(module) spec.loader.exec_module(module)
return module return module
except Exception as e: except Exception as e:
logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}") logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'")
raise e raise e

@ -39,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
) )
retries = 0 retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries: while retries <= max_retries:
try: try:
if dify_config.SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:
@ -52,6 +53,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST: if response.status_code not in STATUS_FORCELIST:
if stream:
return response.iter_bytes()
return response return response
else: else:
logging.warning( logging.warning(

@ -29,6 +29,7 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter,
) )
from core.rag.splitter.text_splitter import TextSplitter from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -500,11 +501,7 @@ class IndexingRunner:
document_node.metadata["doc_hash"] = hash document_node.metadata["doc_hash"] = hash
# delete Splitter character # delete Splitter character
page_content = document_node.page_content page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""): document_node.page_content = remove_leading_symbols(page_content)
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
if document_node.page_content: if document_node.page_content:
split_documents.append(document_node) split_documents.append(document_node)
@ -554,7 +551,7 @@ class IndexingRunner:
qa_documents.append(qa_document) qa_documents.append(qa_document)
format_documents.extend(qa_documents) format_documents.extend(qa_documents)
except Exception as e: except Exception as e:
logging.exception(e) logging.exception("Failed to format qa document")
all_qa_documents.extend(format_documents) all_qa_documents.extend(format_documents)

@ -102,7 +102,7 @@ class LLMGenerator:
except InvokeError: except InvokeError:
questions = [] questions = []
except Exception as e: except Exception as e:
logging.exception(e) logging.exception("Failed to generate suggested questions after answer")
questions = [] questions = []
return questions return questions
@ -148,7 +148,7 @@ class LLMGenerator:
error = str(e) error = str(e)
error_step = "generate rule config" error_step = "generate rule config"
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
rule_config["error"] = str(e) rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -234,7 +234,7 @@ class LLMGenerator:
error_step = "generate conversation opener" error_step = "generate conversation opener"
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
rule_config["error"] = str(e) rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -286,7 +286,9 @@ class LLMGenerator:
error = str(e) error = str(e)
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(
f"Failed to invoke LLM model, model: {model_config.get('name')}, language: {code_language}"
)
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod @classmethod

@ -325,14 +325,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
assistant_prompt_message.tool_calls.append(tool_call) assistant_prompt_message.tool_calls.append(tool_call)
# calculate num tokens # calculate num tokens
if response.usage: prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
# transform usage model, credentials, prompt_messages
prompt_tokens = response.usage.input_tokens )
completion_tokens = response.usage.output_tokens
else: completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
# calculate num tokens model, credentials, [assistant_prompt_message]
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) )
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage # transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

@ -103,7 +103,7 @@ class AzureRerankModel(RerankModel):
return RerankResult(model=model, docs=rerank_documents) return RerankResult(model=model, docs=rerank_documents)
except Exception as e: except Exception as e:
logger.exception(f"Exception in Azure rerank: {e}") logger.exception(f"Failed to invoke rerank model, model: {model}")
raise raise
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:

@ -2,13 +2,11 @@
import base64 import base64
import json import json
import logging import logging
import mimetypes
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
# 3rd import # 3rd import
import boto3 import boto3
import requests
from botocore.config import Config from botocore.config import Config
from botocore.exceptions import ( from botocore.exceptions import (
ClientError, ClientError,
@ -439,18 +437,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE: elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content) message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
url = message_content.data
image_content = requests.get(url).content
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,") data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "") mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1] base64_data = data_split[1]

@ -11,5 +11,6 @@
- gemini-1.5-flash-exp-0827 - gemini-1.5-flash-exp-0827
- gemini-1.5-flash-8b-exp-0827 - gemini-1.5-flash-8b-exp-0827
- gemini-1.5-flash-8b-exp-0924 - gemini-1.5-flash-8b-exp-0924
- gemini-exp-1114
- gemini-pro - gemini-pro
- gemini-pro-vision - gemini-pro-vision

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -24,14 +24,13 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token. en_US: Only sample from the top K options for each subsequent token.
required: false required: false
- name: max_tokens_to_sample - name: max_output_tokens
use_template: max_tokens use_template: max_tokens
required: true
default: 8192 default: 8192
min: 1 min: 1
max: 8192 max: 8192
- name: response_format - name: json_schema
use_template: response_format use_template: json_schema
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'

@ -0,0 +1,38 @@
model: gemini-exp-1114
label:
en_US: Gemini exp 1114
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -32,3 +32,4 @@ pricing:
output: '0.00' output: '0.00'
unit: '0.000001' unit: '0.000001'
currency: USD currency: USD
deprecated: true

@ -36,3 +36,4 @@ pricing:
output: '0.00' output: '0.00'
unit: '0.000001' unit: '0.000001'
currency: USD currency: USD
deprecated: true

@ -1,7 +1,6 @@
import base64 import base64
import io import io
import json import json
import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
class GoogleLargeLanguageModel(LargeLanguageModel): class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke( def _invoke(
@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
try: try:
ping_message = SystemPromptMessage(content="ping") ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
config_kwargs = model_parameters.copy() config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if schema := config_kwargs.pop("json_schema", None):
try:
schema = json.loads(schema)
except:
raise exceptions.InvalidArgument("Invalid JSON Schema")
if tools:
raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
config_kwargs["response_schema"] = schema
config_kwargs["response_mime_type"] = "application/json"
if stop: if stop:
config_kwargs["stop_sequences"] = stop config_kwargs["stop_sequences"] = stop

@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool, PromptMessageTool,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials=credentials, credentials=credentials,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=model_parameters, model_parameters=model_parameters,
tools=tools,
stop=stop, stop=stop,
stream=stream, stream=stream,
user=user, user=user,
@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if completion_type is LLMMode.CHAT: if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "api/chat") endpoint_url = urljoin(endpoint_url, "api/chat")
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
if tools:
data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
else: else:
endpoint_url = urljoin(endpoint_url, "api/generate") endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0] first_prompt_message = prompt_messages[0]
@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if stream: if stream:
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
def _handle_generate_response( def _handle_generate_response(
self, self,
@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_type: LLMMode, completion_type: LLMMode,
response: requests.Response, response: requests.Response,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]],
) -> LLMResult: ) -> LLMResult:
""" """
Handle llm completion response Handle llm completion response
@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: llm result :return: llm result
""" """
response_json = response.json() response_json = response.json()
tool_calls = []
if completion_type is LLMMode.CHAT: if completion_type is LLMMode.CHAT:
message = response_json.get("message", {}) message = response_json.get("message", {})
response_content = message.get("content", "") response_content = message.get("content", "")
response_tool_calls = message.get("tool_calls", [])
tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
else: else:
response_content = response_json["response"] response_content = response_json["response"]
assistant_message = AssistantPromptMessage(content=response_content) assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
if "prompt_eval_count" in response_json and "eval_count" in response_json: if "prompt_eval_count" in response_json and "eval_count" in response_json:
# transform usage # transform usage
@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
chunk_index += 1 chunk_index += 1
def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
"""
Convert PromptMessageTool to dict for Ollama API
:param tool: tool
:return: tool dict
"""
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
""" """
Convert PromptMessage to dict for Ollama API Convert PromptMessage to dict for Ollama API
:param message: prompt message
:return: message dict
""" """
if isinstance(message, UserPromptMessage): if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message) message = cast(UserPromptMessage, message)
@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content}
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return num_tokens return num_tokens
def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
"""
Extract response tool call
"""
tool_call = None
if response_tool_call and "function" in response_tool_call:
# Convert arguments to JSON string if it's a dict
arguments = response_tool_call.get("function").get("arguments")
if isinstance(arguments, dict):
arguments = json.dumps(arguments)
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.get("function").get("name"),
arguments=arguments,
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get("function").get("name"),
type="function",
function=function,
)
return tool_call
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
""" """
Get customizable model schema. Get customizable model schema.
@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: model schema :return: model schema
""" """
extras = {} extras = {
"features": [],
}
if "vision_support" in credentials and credentials["vision_support"] == "true": if "vision_support" in credentials and credentials["vision_support"] == "true":
extras["features"] = [ModelFeature.VISION] extras["features"].append(ModelFeature.VISION)
if "function_call_support" in credentials and credentials["function_call_support"] == "true":
extras["features"].append(ModelFeature.TOOL_CALL)
extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,

@ -96,3 +96,22 @@ model_credential_schema:
label: label:
en_US: 'No' en_US: 'No'
zh_Hans: zh_Hans:
- variable: function_call_support
label:
zh_Hans: 是否支持函数调用
en_US: Function call support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
- value: 'false'
label:
en_US: 'No'
zh_Hans:

@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility # o1 compatibility
block_as_stream = False
if model.startswith("o1"): if model.startswith("o1"):
if "max_tokens" in model_parameters: if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"] del model_parameters["max_tokens"]
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs: if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"] del extra_model_kwargs["stop"]
@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if stream: if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
)
def _handle_chat_generate_response( def _handle_chat_generate_response(
self, self,

@ -45,18 +45,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
self._update_credential(model, credentials) self._update_credential(model, credentials)
block_as_stream = False
if model.startswith("openai/o1"):
block_as_stream = True
stop = None
# invoke block as stream
if stream and block_as_stream:
return self._generate_block_as_stream(
model, credentials, prompt_messages, model_parameters, tools, stop, user
)
else:
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _generate_block_as_stream( def _generate_block_as_stream(
@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> Generator: ) -> Generator:
resp: LLMResult = super()._generate( resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
)
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,

@ -113,7 +113,7 @@ class SageMakerRerankModel(RerankModel):
return RerankResult(model=model, docs=rerank_documents) return RerankResult(model=model, docs=rerank_documents)
except Exception as e: except Exception as e:
logger.exception(f"Exception {e}, line : {line}") logger.exception(f"Failed to invoke rerank model, model: {model}")
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
""" """

@ -78,7 +78,7 @@ class SageMakerSpeech2TextModel(Speech2TextModel):
json_obj = json.loads(json_str) json_obj = json.loads(json_str)
asr_text = json_obj["text"] asr_text = json_obj["text"]
except Exception as e: except Exception as e:
logger.exception(f"failed to invoke speech2text model, {e}") logger.exception(f"failed to invoke speech2text model, model: {model}")
raise CredentialsValidateFailedError(str(e)) raise CredentialsValidateFailedError(str(e))
return asr_text return asr_text

@ -117,7 +117,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):
return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model) return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model)
except Exception as e: except Exception as e:
logger.exception(f"Exception {e}, line : {line}") logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}")
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
""" """

@ -65,6 +65,8 @@ class GTERerankModel(RerankModel):
) )
rerank_documents = [] rerank_documents = []
if not response.output:
return RerankResult(model=model, docs=rerank_documents)
for _, result in enumerate(response.output.results): for _, result in enumerate(response.output.results):
# format document # format document
rerank_document = RerankDocument( rerank_document = RerankDocument(

@ -1,3 +1,6 @@
from collections.abc import Sequence
from typing import Any
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -62,5 +65,5 @@ class KeywordsModeration(Moderation):
def _is_violated(self, inputs: dict, keywords_list: list) -> bool: def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list, value) -> bool: def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
return any(keyword.lower() in value.lower() for keyword in keywords_list) return any(keyword.lower() in str(value).lower() for keyword in keywords_list)

@ -126,6 +126,6 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result return result
except Exception as e: except Exception as e:
logger.exception("Moderation Output error: %s", e) logger.exception(f"Moderation Output error, app_id: {app_id}")
return None return None

@ -49,6 +49,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
@field_validator("inputs", "outputs") @field_validator("inputs", "outputs")
@classmethod @classmethod

@ -25,7 +25,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunType, LangSmithRunType,
LangSmithRunUpdateModel, LangSmithRunUpdateModel,
) )
from core.ops.utils import filter_none_values from core.ops.utils import filter_none_values, generate_dotted_order
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution from models.workflow import WorkflowNodeExecution
@ -62,6 +62,16 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info) self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
workflow_dotted_order = generate_dotted_order(
trace_info.workflow_app_log_id or trace_info.workflow_run_id,
trace_info.workflow_data.created_at,
message_dotted_order,
)
if trace_info.message_id: if trace_info.message_id:
message_run = LangSmithRunModel( message_run = LangSmithRunModel(
id=trace_info.message_id, id=trace_info.message_id,
@ -76,6 +86,8 @@ class LangSmithDataTrace(BaseTraceInstance):
}, },
tags=["message", "workflow"], tags=["message", "workflow"],
error=trace_info.error, error=trace_info.error,
trace_id=trace_id,
dotted_order=message_dotted_order,
) )
self.add_run(message_run) self.add_run(message_run)
@ -95,6 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
error=trace_info.error, error=trace_info.error,
tags=["workflow"], tags=["workflow"],
parent_run_id=trace_info.message_id or None, parent_run_id=trace_info.message_id or None,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
) )
self.add_run(langsmith_run) self.add_run(langsmith_run)
@ -177,6 +191,7 @@ class LangSmithDataTrace(BaseTraceInstance):
else: else:
run_type = LangSmithRunType.tool run_type = LangSmithRunType.tool
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel( langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens, total_tokens=node_total_tokens,
name=node_type, name=node_type,
@ -191,6 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
}, },
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id, parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
tags=["node_execution"], tags=["node_execution"],
id=node_execution_id,
trace_id=trace_id,
dotted_order=node_dotted_order,
) )
self.add_run(langsmith_run) self.add_run(langsmith_run)

@ -711,7 +711,7 @@ class TraceQueueManager:
trace_task.app_id = self.app_id trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task) trace_manager_queue.put(trace_task)
except Exception as e: except Exception as e:
logging.exception(f"Error adding trace task: {e}") logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
finally: finally:
self.start_timer() self.start_timer()
@ -730,7 +730,7 @@ class TraceQueueManager:
if tasks: if tasks:
self.send_to_celery(tasks) self.send_to_celery(tasks)
except Exception as e: except Exception as e:
logging.exception(f"Error processing trace tasks: {e}") logging.exception("Error processing trace tasks")
def start_timer(self): def start_timer(self):
global trace_manager_timer global trace_manager_timer

@ -1,5 +1,6 @@
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from typing import Optional, Union
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Message from models.model import Message
@ -43,3 +44,19 @@ def replace_text_with_content(data):
return [replace_text_with_content(item) for item in data] return [replace_text_with_content(item) for item in data]
else: else:
return data return data
def generate_dotted_order(
run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
) -> str:
"""
generate dotted_order for langsmith
"""
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None:
return current_segment
return f"{parent_dotted_order}.{current_segment}"

@ -1,310 +1,62 @@
import json import json
from typing import Any from typing import Any
from pydantic import BaseModel
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
class AnalyticdbConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = ("dify",)
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector): class AnalyticdbVector(BaseVector):
def __init__(self, collection_name: str, config: AnalyticdbConfig): def __init__(
self._collection_name = collection_name.lower() self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
try: ):
from alibabacloud_gpdb20160503.client import Client super().__init__(collection_name)
from alibabacloud_tea_openapi import models as open_api_models if api_config is not None:
except: self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_indexing_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
if redis_client.get(collection_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else: else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}") self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.ANALYTICDB return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0]) dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension) self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings) self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
for doc, embedding in zip(documents, embeddings, strict=True): self.analyticdb_vector.add_texts(texts, embeddings)
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models return self.analyticdb_vector.text_exists(id)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models self.analyticdb_vector.delete_by_ids(ids)
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models self.analyticdb_vector.delete_by_metadata_field(key, value)
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models return self.analyticdb_vector.search_by_vector(query_vector)
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models return self.analyticdb_vector.search_by_full_text(query, **kwargs)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.metadata.get("vector"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def delete(self) -> None: def delete(self) -> None:
try: self.analyticdb_vector.delete()
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
class AnalyticdbVectorFactory(AbstractVectorFactory): class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
if dataset.index_struct_dict: if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower() collection_name = class_prefix.lower()
@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params if dify_config.ANALYTICDB_HOST is None:
if dify_config.ANALYTICDB_KEY_ID is None: # implemented through OpenAPI
raise ValueError("ANALYTICDB_KEY_ID should not be None") apiConfig = AnalyticdbVectorOpenAPIConfig(
if dify_config.ANALYTICDB_KEY_SECRET is None:
raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
if dify_config.ANALYTICDB_REGION_ID is None:
raise ValueError("ANALYTICDB_REGION_ID should not be None")
if dify_config.ANALYTICDB_INSTANCE_ID is None:
raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
if dify_config.ANALYTICDB_ACCOUNT is None:
raise ValueError("ANALYTICDB_ACCOUNT should not be None")
if dify_config.ANALYTICDB_PASSWORD is None:
raise ValueError("ANALYTICDB_PASSWORD should not be None")
if dify_config.ANALYTICDB_NAMESPACE is None:
raise ValueError("ANALYTICDB_NAMESPACE should not be None")
if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID, access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID, region_id=dify_config.ANALYTICDB_REGION_ID,
@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
account_password=dify_config.ANALYTICDB_PASSWORD, account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE, namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
), )
sqlConfig = None
else:
# implemented through sql
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
)
apiConfig = None
return AnalyticdbVector(
collection_name,
apiConfig,
sqlConfig,
) )

@ -0,0 +1,309 @@
import json
from typing import Any
from pydantic import BaseModel, model_validator
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = "dify"
namespace_password: str = (None,)
metrics: str = "cosine"
read_timeout: int = 60000
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["access_key_id"]:
raise ValueError("config ANALYTICDB_KEY_ID is required")
if not values["access_key_secret"]:
raise ValueError("config ANALYTICDB_KEY_SECRET is required")
if not values["region_id"]:
raise ValueError("config ANALYTICDB_REGION_ID is required")
if not values["instance_id"]:
raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["namespace_password"]:
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e

@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorBySqlConfig(BaseModel):
host: str
port: int
account: str
account_password: str
min_connection: int
max_connection: int
namespace: str = "dify"
metrics: str = "cosine"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ANALYTICDB_HOST is required")
if not values["port"]:
raise ValueError("config ANALYTICDB_PORT is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["min_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
return values
class AnalyticdbVectorBySql:
def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
self._collection_name = collection_name.lower()
self.databaseName = "knowledgebase"
self.config = config
self.table_name = f"{self.config.namespace}.{self._collection_name}"
self.pool = None
self._initialize()
if not self.pool:
self.pool = self._create_connection_pool()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.host}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.host}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _create_connection_pool(self):
return psycopg2.pool.SimpleConnectionPool(
self.config.min_connection,
self.config.max_connection,
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database=self.databaseName,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def _initialize_vector_database(self) -> None:
conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database="postgres",
)
conn.autocommit = True
cur = conn.cursor()
try:
cur.execute(f"CREATE DATABASE {self.databaseName}")
except Exception as e:
if "already exists" in str(e):
return
raise e
finally:
cur.close()
conn.close()
self.pool = self._create_connection_pool()
with self._get_cursor() as cur:
try:
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
except Exception as e:
if "already exists" not in str(e):
raise e
cur.execute(
"CREATE OR REPLACE FUNCTION "
"public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
"RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
"SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
"FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
"AS words_only;$function$"
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(
f"CREATE TABLE IF NOT EXISTS {self.table_name}("
f"id text PRIMARY KEY,"
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
f"to_tsvector TSVECTOR"
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
)
if embedding_dimension is not None:
index_name = f"{self._collection_name}_embedding_idx"
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
cur.execute(
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
f"pq_enable=0, external_storage=0)"
)
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
id_prefix = str(uuid.uuid4()) + "_"
sql = f"""
INSERT INTO {self.table_name}
(id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
return cur.fetchone() is not None
def delete_by_ids(self, ids: list[str]) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
except Exception as e:
if "does not exist" not in str(e):
raise e
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
except Exception as e:
if "does not exist" not in str(e):
raise e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
with self._get_cursor() as cur:
query_vector_str = json.dumps(query_vector)
query_vector_str = "{" + query_vector_str[1:-1] + "}"
cur.execute(
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
f"t.page_content as page_content, t.metadata_ AS metadata_ "
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
(query_vector_str,),
)
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
with self._get_cursor() as cur:
cur.execute(
f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
ORDER BY score DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),
)
documents = []
for record in cur:
id, vector, page_content, metadata, score = record
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector):
try: try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(f"Failed to delete documents, ids: {ids}")
def delete_by_document_id(self, document_id: str): def delete_by_document_id(self, document_id: str):
query = f""" query = f"""

@ -81,7 +81,7 @@ class LindormVectorStore(BaseVector):
"ids": batch_ids}, _source=False) "ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e: except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}: {e}") logger.exception(f"Error fetching batch {batch_ids}")
return set() return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
@ -99,7 +99,7 @@ class LindormVectorStore(BaseVector):
) )
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e: except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}: {e}") logger.exception(f"Error fetching batch ids: {batch_ids}")
return set() return set()
if ids is None: if ids is None:
@ -187,7 +187,7 @@ class LindormVectorStore(BaseVector):
logger.warning( logger.warning(
f"Index '{self._collection_name}' does not exist. No deletion performed.") f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e: except Exception as e:
logger.exception(f"Error occurred while deleting the index: {e}") logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
raise e raise e
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
@ -213,7 +213,7 @@ class LindormVectorStore(BaseVector):
response = self._client.search( response = self._client.search(
index=self._collection_name, body=query) index=self._collection_name, body=query)
except Exception as e: except Exception as e:
logger.exception(f"Error executing search: {e}") logger.exception(f"Error executing vector search, query: {query}")
raise raise
docs_and_scores = [] docs_and_scores = []

@ -142,7 +142,7 @@ class MyScaleVector(BaseVector):
for r in self._client.query(sql).named_results() for r in self._client.query(sql).named_results()
] ]
except Exception as e: except Exception as e:
logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") # noqa:TRY401
return [] return []
def delete(self) -> None: def delete(self) -> None:

@ -158,7 +158,7 @@ class OpenSearchVector(BaseVector):
try: try:
response = self._client.search(index=self._collection_name.lower(), body=query) response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception as e: except Exception as e:
logger.exception(f"Error executing search: {e}") logger.exception(f"Error executing vector search, query: {query}")
raise raise
docs = [] docs = []

@ -69,7 +69,7 @@ class CacheEmbedding(Embeddings):
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
except Exception as e: except Exception as e:
logging.exception("Failed transform embedding: %s", e) logging.exception("Failed transform embedding")
cache_embeddings = [] cache_embeddings = []
try: try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
@ -89,7 +89,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback() db.session.rollback()
except Exception as ex: except Exception as ex:
db.session.rollback() db.session.rollback()
logger.exception("Failed to embed documents: %s", ex) logger.exception("Failed to embed documents: %s")
raise ex raise ex
return text_embeddings return text_embeddings
@ -112,7 +112,7 @@ class CacheEmbedding(Embeddings):
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex: except Exception as ex:
if dify_config.DEBUG: if dify_config.DEBUG:
logging.exception(f"Failed to embed query text: {ex}") logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
raise ex raise ex
try: try:
@ -126,7 +126,7 @@ class CacheEmbedding(Embeddings):
redis_client.setex(embedding_cache_key, 600, encoded_str) redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex: except Exception as ex:
if dify_config.DEBUG: if dify_config.DEBUG:
logging.exception("Failed to add embedding to redis %s", ex) logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
raise ex raise ex
return embedding_results return embedding_results

@ -229,7 +229,7 @@ class WordExtractor(BaseExtractor):
for i in url_pattern.findall(x.text): for i in url_pattern.findall(x.text):
hyperlinks_url = str(i) hyperlinks_url = str(i)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception("Failed to parse HYPERLINK xml")
def parse_paragraph(paragraph): def parse_paragraph(paragraph):
paragraph_content = [] paragraph_content = []

@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper from libs import helper
from models.dataset import Dataset from models.dataset import Dataset
@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_id"] = doc_id document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash document_node.metadata["doc_hash"] = hash
# delete Splitter character # delete Splitter character
page_content = document_node.page_content page_content = remove_leading_symbols(document_node.page_content).strip()
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:].strip()
else:
page_content = page_content
if len(page_content) > 0: if len(page_content) > 0:
document_node.page_content = page_content document_node.page_content = page_content
split_documents.append(document_node) split_documents.append(document_node)

@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper from libs import helper
from models.dataset import Dataset from models.dataset import Dataset
@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.metadata["doc_hash"] = hash document_node.metadata["doc_hash"] = hash
# delete Splitter character # delete Splitter character
page_content = document_node.page_content page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""): document_node.page_content = remove_leading_symbols(page_content)
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
split_documents.append(document_node) split_documents.append(document_node)
all_documents.extend(split_documents) all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10): for i in range(0, len(all_documents), 10):
@ -159,7 +156,7 @@ class QAIndexProcessor(BaseIndexProcessor):
qa_documents.append(qa_document) qa_documents.append(qa_document)
format_documents.extend(qa_documents) format_documents.extend(qa_documents)
except Exception as e: except Exception as e:
logging.exception(e) logging.exception("Failed to format qa document")
all_qa_documents.extend(format_documents) all_qa_documents.extend(format_documents)

@ -36,23 +36,21 @@ class WeightRerankRunner(BaseRerankRunner):
:return: :return:
""" """
docs = []
doc_id = []
unique_documents = [] unique_documents = []
doc_id = set()
for document in documents: for document in documents:
if document.metadata["doc_id"] not in doc_id: doc_id = document.metadata.get("doc_id")
doc_id.append(document.metadata["doc_id"]) if doc_id not in doc_id:
docs.append(document.page_content) doc_id.add(doc_id)
unique_documents.append(document) unique_documents.append(document)
documents = unique_documents documents = unique_documents
rerank_documents = []
query_scores = self._calculate_keyword_score(query, documents) query_scores = self._calculate_keyword_score(query, documents)
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
rerank_documents = []
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
# format document
score = ( score = (
self.weights.vector_setting.vector_weight * query_vector_score self.weights.vector_setting.vector_weight * query_vector_score
+ self.weights.keyword_setting.keyword_weight * query_score + self.weights.keyword_setting.keyword_weight * query_score
@ -61,7 +59,8 @@ class WeightRerankRunner(BaseRerankRunner):
continue continue
document.metadata["score"] = score document.metadata["score"] = score
rerank_documents.append(document) rerank_documents.append(document)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)
rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents return rerank_documents[:top_n] if top_n else rerank_documents
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:

@ -0,0 +1,87 @@
from typing import Any
from duckduckgo_search import DDGS
from core.model_runtime.entities.message_entities import SystemPromptMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
SUMMARY_PROMPT = """
User's query:
{query}
Here are the news results:
{content}
Please summarize the news in a few sentences.
"""
class DuckDuckGoNewsSearchTool(BuiltinTool):
"""
Tool for performing a news search using DuckDuckGo search engine.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
query_dict = {
"keywords": tool_parameters.get("query"),
"timelimit": tool_parameters.get("timelimit"),
"max_results": tool_parameters.get("max_results"),
"safesearch": "moderate",
"region": "wt-wt",
}
try:
response = list(DDGS().news(**query_dict))
if not response:
return [self.create_text_message("No news found matching your criteria.")]
except Exception as e:
return [self.create_text_message(f"Error searching news: {str(e)}")]
require_summary = tool_parameters.get("require_summary", False)
if require_summary:
results = "\n".join([f"{res.get('title')}: {res.get('body')}" for res in response])
results = self.summary_results(user_id=user_id, content=results, query=query_dict["keywords"])
return self.create_text_message(text=results)
# Create rich markdown content for each news item
markdown_result = "\n\n"
json_result = []
for res in response:
markdown_result += f"### {res.get('title', 'Untitled')}\n\n"
if res.get("date"):
markdown_result += f"**Date:** {res.get('date')}\n\n"
if res.get("body"):
markdown_result += f"{res.get('body')}\n\n"
if res.get("source"):
markdown_result += f"*Source: {res.get('source')}*\n\n"
if res.get("image"):
markdown_result += f"![{res.get('title', '')}]({res.get('image')})\n\n"
markdown_result += f"[Read more]({res.get('url', '')})\n\n---\n\n"
json_result.append(
self.create_json_message(
{
"title": res.get("title", ""),
"date": res.get("date", ""),
"body": res.get("body", ""),
"url": res.get("url", ""),
"image": res.get("image", ""),
"source": res.get("source", ""),
}
)
)
return [self.create_text_message(markdown_result)] + json_result
def summary_results(self, user_id: str, content: str, query: str) -> str:
prompt = SUMMARY_PROMPT.format(query=query, content=content)
summary = self.invoke_model(
user_id=user_id,
prompt_messages=[
SystemPromptMessage(content=prompt),
],
stop=[],
)
return summary.message.content

@ -0,0 +1,71 @@
identity:
name: ddgo_news
author: Assistant
label:
en_US: DuckDuckGo News Search
zh_Hans: DuckDuckGo 新闻搜索
description:
human:
en_US: Perform news searches on DuckDuckGo and get results.
zh_Hans: 在 DuckDuckGo 上进行新闻搜索并获取结果。
llm: Perform news searches on DuckDuckGo and get results.
parameters:
- name: query
type: string
required: true
label:
en_US: Query String
zh_Hans: 查询语句
human_description:
en_US: Search Query.
zh_Hans: 搜索查询语句。
llm_description: Key words for searching
form: llm
- name: max_results
type: number
required: true
default: 5
label:
en_US: Max Results
zh_Hans: 最大结果数量
human_description:
en_US: The Max Results
zh_Hans: 最大结果数量
form: form
- name: timelimit
type: select
required: false
options:
- value: Day
label:
en_US: Current Day
zh_Hans: 当天
- value: Week
label:
en_US: Current Week
zh_Hans: 本周
- value: Month
label:
en_US: Current Month
zh_Hans: 当月
- value: Year
label:
en_US: Current Year
zh_Hans: 今年
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
human_description:
en_US: Use when querying results within a specific time range only.
zh_Hans: 只查询一定时间范围内的结果时使用
form: form
- name: require_summary
type: boolean
default: false
label:
en_US: Require Summary
zh_Hans: 是否总结
human_description:
en_US: Whether to pass the news results to llm for summarization.
zh_Hans: 是否需要将新闻结果传给大模型总结
form: form

@ -0,0 +1,75 @@
from typing import Any, ClassVar
from duckduckgo_search import DDGS
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class DuckDuckGoVideoSearchTool(BuiltinTool):
"""
Tool for performing a video search using DuckDuckGo search engine.
"""
IFRAME_TEMPLATE: ClassVar[str] = """
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
<iframe
style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
src="{src}"
frameborder="0"
allowfullscreen>
</iframe>
</div>"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
query_dict = {
"keywords": tool_parameters.get("query"),
"region": tool_parameters.get("region", "wt-wt"),
"safesearch": tool_parameters.get("safesearch", "moderate"),
"timelimit": tool_parameters.get("timelimit"),
"resolution": tool_parameters.get("resolution"),
"duration": tool_parameters.get("duration"),
"license_videos": tool_parameters.get("license_videos"),
"max_results": tool_parameters.get("max_results"),
}
# Remove None values to use API defaults
query_dict = {k: v for k, v in query_dict.items() if v is not None}
# Get proxy URL from parameters
proxy_url = tool_parameters.get("proxy_url", "").strip()
response = DDGS().videos(**query_dict)
# Create HTML result with embedded iframes
markdown_result = "\n\n"
json_result = []
for res in response:
title = res.get("title", "")
embed_html = res.get("embed_html", "")
description = res.get("description", "")
content_url = res.get("content", "")
# Handle TED.com videos
if not embed_html and "ted.com/talks" in content_url:
embed_url = content_url.replace("www.ted.com", "embed.ted.com")
if proxy_url:
embed_url = f"{proxy_url}{embed_url}"
embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)
# Original YouTube/other platform handling
elif embed_html:
embed_url = res.get("embed_url", "")
if proxy_url and embed_url:
embed_url = f"{proxy_url}{embed_url}"
embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)
markdown_result += f"{title}\n\n"
markdown_result += f"{embed_html}\n\n"
markdown_result += "---\n\n"
json_result.append(self.create_json_message(res))
return [self.create_text_message(markdown_result)] + json_result

@ -0,0 +1,97 @@
identity:
name: ddgo_video
author: Tao Wang
label:
en_US: DuckDuckGo Video Search
zh_Hans: DuckDuckGo 视频搜索
description:
human:
en_US: Search and embedded videos.
zh_Hans: 搜索并嵌入视频
llm: Search videos on duckduckgo and embed videos in iframe
parameters:
- name: query
label:
en_US: Query String
zh_Hans: 查询语句
type: string
required: true
human_description:
en_US: Search Query
zh_Hans: 搜索查询语句
llm_description: Key words for searching
form: llm
- name: max_results
label:
en_US: Max Results
zh_Hans: 最大结果数量
type: number
required: true
default: 3
minimum: 1
maximum: 10
human_description:
en_US: The max results (1-10)
zh_Hans: 最大结果数量1-10
form: form
- name: timelimit
label:
en_US: Result Time Limit
zh_Hans: 结果时间限制
type: select
required: false
options:
- value: Day
label:
en_US: Current Day
zh_Hans: 当天
- value: Week
label:
en_US: Current Week
zh_Hans: 本周
- value: Month
label:
en_US: Current Month
zh_Hans: 当月
- value: Year
label:
en_US: Current Year
zh_Hans: 今年
human_description:
en_US: Query results within a specific time range only
zh_Hans: 只查询一定时间范围内的结果时使用
form: form
- name: duration
label:
en_US: Video Duration
zh_Hans: 视频时长
type: select
required: false
options:
- value: short
label:
en_US: Short (<4 minutes)
zh_Hans: 短视频(<4分钟
- value: medium
label:
en_US: Medium (4-20 minutes)
zh_Hans: 中等4-20分钟
- value: long
label:
en_US: Long (>20 minutes)
zh_Hans: 长视频(>20分钟
human_description:
en_US: Filter videos by duration
zh_Hans: 按时长筛选视频
form: form
- name: proxy_url
label:
en_US: Proxy URL
zh_Hans: 视频代理地址
type: string
required: false
default: ""
human_description:
en_US: Proxy URL
zh_Hans: 视频代理地址
form: form

@ -38,7 +38,7 @@ def send_mail(parmas: SendEmailToolParameters):
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string()) server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
return True return True
except Exception as e: except Exception as e:
logging.exception("send email failed: %s", e) logging.exception("send email failed")
return False return False
else: # NONE or TLS else: # NONE or TLS
try: try:
@ -49,5 +49,5 @@ def send_mail(parmas: SendEmailToolParameters):
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string()) server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
return True return True
except Exception as e: except Exception as e:
logging.exception("send email failed: %s", e) logging.exception("send email failed")
return False return False

@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool):
invoke tools invoke tools
""" """
sender = self.runtime.credentials.get("email_account", "") sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$") email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "") password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "") smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server: if not smtp_server:

@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool):
invoke tools invoke tools
""" """
sender = self.runtime.credentials.get("email_account", "") sender = self.runtime.credentials.get("email_account", "")
email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$") email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
password = self.runtime.credentials.get("email_password", "") password = self.runtime.credentials.get("email_password", "")
smtp_server = self.runtime.credentials.get("smtp_server", "") smtp_server = self.runtime.credentials.get("smtp_server", "")
if not smtp_server: if not smtp_server:

@ -19,7 +19,7 @@ class WizperTool(BuiltinTool):
version = tool_parameters.get("version", "3") version = tool_parameters.get("version", "3")
if audio_file.type != FileType.AUDIO: if audio_file.type != FileType.AUDIO:
return [self.create_text_message("Not a valid audio file.")] return self.create_text_message("Not a valid audio file.")
api_key = self.runtime.credentials["fal_api_key"] api_key = self.runtime.credentials["fal_api_key"]
@ -31,9 +31,8 @@ class WizperTool(BuiltinTool):
try: try:
audio_url = fal_client.upload(file_data, mime_type) audio_url = fal_client.upload(file_data, mime_type)
except Exception as e: except Exception as e:
return [self.create_text_message(f"Error uploading audio file: {str(e)}")] return self.create_text_message(f"Error uploading audio file: {str(e)}")
arguments = { arguments = {
"audio_url": audio_url, "audio_url": audio_url,
@ -49,4 +48,9 @@ class WizperTool(BuiltinTool):
with_logs=False, with_logs=False,
) )
return self.create_json_message(result) json_message = self.create_json_message(result)
text = result.get("text", "")
text_message = self.create_text_message(text)
return [json_message, text_message]

@ -0,0 +1,25 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAIToolEmbedding(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
headers = {
"content-type": "application/json",
"authorization": f"Bearer {self.runtime.credentials['api_key']}",
}
payload = {"inputs": tool_parameters.get("inputs")}
model = tool_parameters.get("model", "bge-m3")
url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
return self.create_text_message(f"Got Error Response:{response.text}")
return [self.create_text_message(response.content.decode("utf-8"))]

@ -0,0 +1,37 @@
identity:
name: embedding
author: gitee_ai
label:
en_US: embedding
icon: icon.svg
description:
human:
en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
llm: This tool is used to generate word embeddings from text input.
parameters:
- name: model
type: string
required: true
in: path
description:
en_US: Supported Embedding (compatible with OpenAI) interface models
enum:
- bge-m3
- bge-large-zh-v1.5
- bge-small-zh-v1.5
label:
en_US: Service Model
zh_Hans: 服务模型
default: bge-m3
form: form
- name: inputs
type: string
required: true
label:
en_US: Input Text
zh_Hans: 输入文本
human_description:
en_US: The text input used to generate embeddings.
zh_Hans: 用于生成词向量的输入文本。
llm_description: This text input will be used to generate embeddings.
form: llm

@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
class GiteeAITool(BuiltinTool): class GiteeAIToolText2Image(BuiltinTool):
def _invoke( def _invoke(
self, user_id: str, tool_parameters: dict[str, Any] self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

@ -1,14 +1,12 @@
identity: identity:
author: Yash Parmar author: Yash Parmar, Kalo Chin
name: tavily name: tavily
label: label:
en_US: Tavily en_US: Tavily Search & Extract
zh_Hans: Tavily zh_Hans: Tavily 搜索和提取
pt_BR: Tavily
description: description:
en_US: Tavily en_US: A powerful AI-native search engine and web content extraction tool that provides highly relevant search results and raw content extraction from web pages.
zh_Hans: Tavily zh_Hans: 一个强大的原生AI搜索引擎和网页内容提取工具提供高度相关的搜索结果和网页原始内容提取。
pt_BR: Tavily
icon: icon.png icon: icon.png
tags: tags:
- search - search
@ -19,13 +17,10 @@ credentials_for_provider:
label: label:
en_US: Tavily API key en_US: Tavily API key
zh_Hans: Tavily API key zh_Hans: Tavily API key
pt_BR: Tavily API key
placeholder: placeholder:
en_US: Please input your Tavily API key en_US: Please input your Tavily API key
zh_Hans: 请输入你的 Tavily API key zh_Hans: 请输入你的 Tavily API key
pt_BR: Please input your Tavily API key
help: help:
en_US: Get your Tavily API key from Tavily en_US: Get your Tavily API key from Tavily
zh_Hans: 从 TavilyApi 获取您的 Tavily API key zh_Hans: 从 TavilyApi 获取您的 Tavily API key
pt_BR: Get your Tavily API key from Tavily url: https://app.tavily.com/home
url: https://docs.tavily.com/docs/welcome

@ -0,0 +1,145 @@
from typing import Any
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
TAVILY_API_URL = "https://api.tavily.com"
class TavilyExtract:
"""
A class for extracting content from web pages using the Tavily Extract API.
Args:
api_key (str): The API key for accessing the Tavily Extract API.
Methods:
extract_content: Retrieves extracted content from the Tavily Extract API.
"""
def __init__(self, api_key: str) -> None:
self.api_key = api_key
def extract_content(self, params: dict[str, Any]) -> dict:
"""
Retrieves extracted content from the Tavily Extract API.
Args:
params (Dict[str, Any]): The extraction parameters.
Returns:
dict: The extracted content.
"""
# Ensure required parameters are set
if "api_key" not in params:
params["api_key"] = self.api_key
# Process parameters
processed_params = self._process_params(params)
response = requests.post(f"{TAVILY_API_URL}/extract", json=processed_params)
response.raise_for_status()
return response.json()
def _process_params(self, params: dict[str, Any]) -> dict:
"""
Processes and validates the extraction parameters.
Args:
params (Dict[str, Any]): The extraction parameters.
Returns:
dict: The processed parameters.
"""
processed_params = {}
# Process 'urls'
if "urls" in params:
urls = params["urls"]
if isinstance(urls, str):
processed_params["urls"] = [url.strip() for url in urls.replace(",", " ").split()]
elif isinstance(urls, list):
processed_params["urls"] = urls
else:
raise ValueError("The 'urls' parameter is required.")
# Only include 'api_key'
processed_params["api_key"] = params.get("api_key", self.api_key)
return processed_params
class TavilyExtractTool(BuiltinTool):
"""
A tool for extracting content from web pages using Tavily Extract.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invokes the Tavily Extract tool with the given user ID and tool parameters.
Args:
user_id (str): The ID of the user invoking the tool.
tool_parameters (Dict[str, Any]): The parameters for the Tavily Extract tool.
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily Extract tool invocation.
"""
urls = tool_parameters.get("urls", "")
api_key = self.runtime.credentials.get("tavily_api_key")
if not api_key:
return self.create_text_message(
"Tavily API key is missing. Please set the 'tavily_api_key' in credentials."
)
if not urls:
return self.create_text_message("Please input at least one URL to extract.")
tavily_extract = TavilyExtract(api_key)
try:
raw_results = tavily_extract.extract_content(tool_parameters)
except requests.HTTPError as e:
return self.create_text_message(f"Error occurred while extracting content: {str(e)}")
if not raw_results.get("results"):
return self.create_text_message("No content could be extracted from the provided URLs.")
else:
# Always return JSON message with all data
json_message = self.create_json_message(raw_results)
# Create text message based on user-selected parameters
text_message_content = self._format_results_as_text(raw_results)
text_message = self.create_text_message(text=text_message_content)
return [json_message, text_message]
def _format_results_as_text(self, raw_results: dict) -> str:
"""
Formats the raw extraction results into a markdown text based on user-selected parameters.
Args:
raw_results (dict): The raw extraction results.
Returns:
str: The formatted markdown text.
"""
output_lines = []
for idx, result in enumerate(raw_results.get("results", []), 1):
url = result.get("url", "")
raw_content = result.get("raw_content", "")
output_lines.append(f"## Extracted Content {idx}: {url}\n")
output_lines.append(f"**Raw Content:**\n{raw_content}\n")
output_lines.append("---\n")
if raw_results.get("failed_results"):
output_lines.append("## Failed URLs:\n")
for failed in raw_results["failed_results"]:
url = failed.get("url", "")
error = failed.get("error", "Unknown error")
output_lines.append(f"- {url}: {error}\n")
return "\n".join(output_lines)

@ -0,0 +1,23 @@
identity:
name: tavily_extract
author: Kalo Chin
label:
en_US: Tavily Extract
zh_Hans: Tavily Extract
description:
human:
en_US: A web extraction tool built specifically for AI agents (LLMs), delivering raw content from web pages.
zh_Hans: 专为人工智能代理 (LLM) 构建的网页提取工具,提供网页的原始内容。
llm: A tool for extracting raw content from web pages, designed for AI agents (LLMs).
parameters:
- name: urls
type: string
required: true
label:
en_US: URLs
zh_Hans: URLs
human_description:
en_US: A comma-separated list of URLs to extract content from.
zh_Hans: 要从中提取内容的 URL 的逗号分隔列表。
llm_description: A comma-separated list of URLs to extract content from.
form: llm

@ -17,8 +17,6 @@ class TavilySearch:
Methods: Methods:
raw_results: Retrieves raw search results from the Tavily Search API. raw_results: Retrieves raw search results from the Tavily Search API.
results: Retrieves cleaned search results from the Tavily Search API.
clean_results: Cleans the raw search results.
""" """
def __init__(self, api_key: str) -> None: def __init__(self, api_key: str) -> None:
@ -35,63 +33,62 @@ class TavilySearch:
dict: The raw search results. dict: The raw search results.
""" """
# Ensure required parameters are set
params["api_key"] = self.api_key params["api_key"] = self.api_key
if (
"exclude_domains" in params
and isinstance(params["exclude_domains"], str)
and params["exclude_domains"] != "None"
):
params["exclude_domains"] = params["exclude_domains"].split()
else:
params["exclude_domains"] = []
if (
"include_domains" in params
and isinstance(params["include_domains"], str)
and params["include_domains"] != "None"
):
params["include_domains"] = params["include_domains"].split()
else:
params["include_domains"] = []
response = requests.post(f"{TAVILY_API_URL}/search", json=params) # Process parameters to ensure correct types
processed_params = self._process_params(params)
response = requests.post(f"{TAVILY_API_URL}/search", json=processed_params)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def results(self, params: dict[str, Any]) -> list[dict]: def _process_params(self, params: dict[str, Any]) -> dict:
""" """
Retrieves cleaned search results from the Tavily Search API. Processes and validates the search parameters.
Args: Args:
params (Dict[str, Any]): The search parameters. params (Dict[str, Any]): The search parameters.
Returns: Returns:
list: The cleaned search results. dict: The processed parameters.
""" """
raw_search_results = self.raw_results(params) processed_params = {}
return self.clean_results(raw_search_results["results"])
for key, value in params.items():
def clean_results(self, results: list[dict]) -> list[dict]: if value is None or value == "None":
""" continue
Cleans the raw search results. if key in ["include_domains", "exclude_domains"]:
if isinstance(value, str):
# Split the string by commas or spaces and strip whitespace
processed_params[key] = [domain.strip() for domain in value.replace(",", " ").split()]
elif key in ["include_images", "include_image_descriptions", "include_answer", "include_raw_content"]:
# Ensure boolean type
if isinstance(value, str):
processed_params[key] = value.lower() == "true"
else:
processed_params[key] = bool(value)
elif key in ["max_results", "days"]:
if isinstance(value, str):
processed_params[key] = int(value)
else:
processed_params[key] = value
elif key in ["search_depth", "topic", "query", "api_key"]:
processed_params[key] = value
else:
# Unrecognized parameter
pass
Args: # Set defaults if not present
results (list): The raw search results. processed_params.setdefault("search_depth", "basic")
processed_params.setdefault("topic", "general")
processed_params.setdefault("max_results", 5)
Returns: # If topic is 'news', ensure 'days' is set
list: The cleaned search results. if processed_params.get("topic") == "news":
processed_params.setdefault("days", 3)
""" return processed_params
clean_results = []
for result in results:
clean_results.append(
{
"url": result["url"],
"content": result["content"],
}
)
# return clean results as a string
return "\n".join([f"{res['url']}\n{res['content']}" for res in clean_results])
class TavilySearchTool(BuiltinTool): class TavilySearchTool(BuiltinTool):
@ -111,14 +108,88 @@ class TavilySearchTool(BuiltinTool):
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily search tool invocation. ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily search tool invocation.
""" """
query = tool_parameters.get("query", "") query = tool_parameters.get("query", "")
api_key = self.runtime.credentials.get("tavily_api_key")
api_key = self.runtime.credentials["tavily_api_key"] if not api_key:
return self.create_text_message(
"Tavily API key is missing. Please set the 'tavily_api_key' in credentials."
)
if not query: if not query:
return self.create_text_message("Please input query") return self.create_text_message("Please input a query.")
tavily_search = TavilySearch(api_key) tavily_search = TavilySearch(api_key)
results = tavily_search.results(tool_parameters) try:
print(results) raw_results = tavily_search.raw_results(tool_parameters)
if not results: except requests.HTTPError as e:
return self.create_text_message(f"No results found for '{query}' in Tavily") return self.create_text_message(f"Error occurred while searching: {str(e)}")
if not raw_results.get("results"):
return self.create_text_message(f"No results found for '{query}' in Tavily.")
else:
# Always return JSON message with all data
json_message = self.create_json_message(raw_results)
# Create text message based on user-selected parameters
text_message_content = self._format_results_as_text(raw_results, tool_parameters)
text_message = self.create_text_message(text=text_message_content)
return [json_message, text_message]
def _format_results_as_text(self, raw_results: dict, tool_parameters: dict[str, Any]) -> str:
"""
Formats the raw results into a markdown text based on user-selected parameters.
Args:
raw_results (dict): The raw search results.
tool_parameters (dict): The tool parameters selected by the user.
Returns:
str: The formatted markdown text.
"""
output_lines = []
# Include answer if requested
if tool_parameters.get("include_answer", False) and raw_results.get("answer"):
output_lines.append(f"**Answer:** {raw_results['answer']}\n")
# Include images if requested
if tool_parameters.get("include_images", False) and raw_results.get("images"):
output_lines.append("**Images:**\n")
for image in raw_results["images"]:
if tool_parameters.get("include_image_descriptions", False) and "description" in image:
output_lines.append(f"![{image['description']}]({image['url']})\n")
else: else:
return self.create_text_message(text=results) output_lines.append(f"![]({image['url']})\n")
# Process each result
if "results" in raw_results:
for idx, result in enumerate(raw_results["results"], 1):
title = result.get("title", "No Title")
url = result.get("url", "")
content = result.get("content", "")
published_date = result.get("published_date", "")
score = result.get("score", "")
output_lines.append(f"### Result {idx}: [{title}]({url})\n")
# Include published date if available and topic is 'news'
if tool_parameters.get("topic") == "news" and published_date:
output_lines.append(f"**Published Date:** {published_date}\n")
output_lines.append(f"**URL:** {url}\n")
# Include score (relevance)
if score:
output_lines.append(f"**Relevance Score:** {score}\n")
# Include content
if content:
output_lines.append(f"**Content:**\n{content}\n")
# Include raw content if requested
if tool_parameters.get("include_raw_content", False) and result.get("raw_content"):
output_lines.append(f"**Raw Content:**\n{result['raw_content']}\n")
# Add a separator
output_lines.append("---\n")
return "\n".join(output_lines)

@ -2,28 +2,24 @@ identity:
name: tavily_search name: tavily_search
author: Yash Parmar author: Yash Parmar
label: label:
en_US: TavilySearch en_US: Tavily Search
zh_Hans: TavilySearch zh_Hans: Tavily Search
pt_BR: TavilySearch
description: description:
human: human:
en_US: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed. en_US: A search engine tool built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
zh_Hans: 专为人工智能代理 (LLM) 构建的搜索引擎工具,可快速提供实时、准确和真实的结果。 zh_Hans: 专为人工智能代理 (LLM) 构建的搜索引擎工具,可快速提供实时、准确和真实的结果。
pt_BR: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
llm: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed. llm: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
parameters: parameters:
- name: query - name: query
type: string type: string
required: true required: true
label: label:
en_US: Query string en_US: Query
zh_Hans: 查询语句 zh_Hans: 查询
pt_BR: Query string
human_description: human_description:
en_US: used for searching en_US: The search query you want to execute with Tavily.
zh_Hans: 用于搜索网页内容 zh_Hans: 您想用 Tavily 执行的搜索查询。
pt_BR: used for searching llm_description: The search query.
llm_description: key words for searching
form: llm form: llm
- name: search_depth - name: search_depth
type: select type: select
@ -31,122 +27,118 @@ parameters:
label: label:
en_US: Search Depth en_US: Search Depth
zh_Hans: 搜索深度 zh_Hans: 搜索深度
pt_BR: Search Depth
human_description: human_description:
en_US: The depth of search results en_US: The depth of the search.
zh_Hans: 搜索结果的深度 zh_Hans: 搜索的深度。
pt_BR: The depth of search results
form: form form: form
options: options:
- value: basic - value: basic
label: label:
en_US: Basic en_US: Basic
zh_Hans: 基本 zh_Hans: 基本
pt_BR: Basic
- value: advanced - value: advanced
label: label:
en_US: Advanced en_US: Advanced
zh_Hans: 高级 zh_Hans: 高级
pt_BR: Advanced
default: basic default: basic
- name: topic
type: select
required: false
label:
en_US: Topic
zh_Hans: 主题
human_description:
en_US: The category of the search.
zh_Hans: 搜索的类别。
form: form
options:
- value: general
label:
en_US: General
zh_Hans: 一般
- value: news
label:
en_US: News
zh_Hans: 新闻
default: general
- name: days
type: number
required: false
label:
en_US: Days
zh_Hans: 天数
human_description:
en_US: The number of days back from the current date to include in the search results (only applicable when "topic" is "news").
zh_Hans: 从当前日期起向前追溯的天数以包含在搜索结果中仅当“topic”为“news”时适用
form: form
min: 1
default: 3
- name: max_results
type: number
required: false
label:
en_US: Max Results
zh_Hans: 最大结果数
human_description:
en_US: The maximum number of search results to return.
zh_Hans: 要返回的最大搜索结果数。
form: form
min: 1
max: 20
default: 5
- name: include_images - name: include_images
type: boolean type: boolean
required: false required: false
label: label:
en_US: Include Images en_US: Include Images
zh_Hans: 包含图片 zh_Hans: 包含图片
pt_BR: Include Images
human_description: human_description:
en_US: Include images in the search results en_US: Include a list of query-related images in the response.
zh_Hans: 在搜索结果中包含图片 zh_Hans: 在响应中包含与查询相关的图片列表。
pt_BR: Include images in the search results
form: form form: form
options: default: false
- value: 'true' - name: include_image_descriptions
label: type: boolean
en_US: 'Yes' required: false
zh_Hans:
pt_BR: 'Yes'
- value: 'false'
label: label:
en_US: 'No' en_US: Include Image Descriptions
zh_Hans: zh_Hans: 包含图片描述
pt_BR: 'No' human_description:
default: 'false' en_US: When include_images is True, adds descriptive text for each image.
zh_Hans: 当 include_images 为 True 时,为每个图像添加描述文本。
form: form
default: false
- name: include_answer - name: include_answer
type: boolean type: boolean
required: false required: false
label: label:
en_US: Include Answer en_US: Include Answer
zh_Hans: 包含答案 zh_Hans: 包含答案
pt_BR: Include Answer
human_description: human_description:
en_US: Include answers in the search results en_US: Include a short answer to the original query in the response.
zh_Hans: 在搜索结果中包含答案 zh_Hans: 在响应中包含对原始查询的简短回答。
pt_BR: Include answers in the search results
form: form form: form
options: default: false
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
pt_BR: 'Yes'
- value: 'false'
label:
en_US: 'No'
zh_Hans:
pt_BR: 'No'
default: 'false'
- name: include_raw_content - name: include_raw_content
type: boolean type: boolean
required: false required: false
label: label:
en_US: Include Raw Content en_US: Include Raw Content
zh_Hans: 包含原始内容 zh_Hans: 包含原始内容
pt_BR: Include Raw Content
human_description: human_description:
en_US: Include raw content in the search results en_US: Include the cleaned and parsed HTML content of each search result.
zh_Hans: 在搜索结果中包含原始内容 zh_Hans: 包含每个搜索结果的已清理和解析的HTML内容。
pt_BR: Include raw content in the search results
form: form form: form
options: default: false
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
pt_BR: 'Yes'
- value: 'false'
label:
en_US: 'No'
zh_Hans:
pt_BR: 'No'
default: 'false'
- name: max_results
type: number
required: false
label:
en_US: Max Results
zh_Hans: 最大结果
pt_BR: Max Results
human_description:
en_US: The number of maximum search results to return
zh_Hans: 返回的最大搜索结果数
pt_BR: The number of maximum search results to return
form: form
min: 1
max: 20
default: 5
- name: include_domains - name: include_domains
type: string type: string
required: false required: false
label: label:
en_US: Include Domains en_US: Include Domains
zh_Hans: 包含域 zh_Hans: 包含域
pt_BR: Include Domains
human_description: human_description:
en_US: A list of domains to specifically include in the search results en_US: A comma-separated list of domains to specifically include in the search results.
zh_Hans: 在搜索结果中特别包含的域名列表 zh_Hans: 要在搜索结果中特别包含的域的逗号分隔列表。
pt_BR: A list of domains to specifically include in the search results
form: form form: form
- name: exclude_domains - name: exclude_domains
type: string type: string
@ -154,9 +146,7 @@ parameters:
label: label:
en_US: Exclude Domains en_US: Exclude Domains
zh_Hans: 排除域 zh_Hans: 排除域
pt_BR: Exclude Domains
human_description: human_description:
en_US: A list of domains to specifically exclude from the search results en_US: A comma-separated list of domains to specifically exclude from the search results.
zh_Hans: 从搜索结果中特别排除的域名列表 zh_Hans: 要从搜索结果中特别排除的域的逗号分隔列表。
pt_BR: A list of domains to specifically exclude from the search results
form: form form: form

@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="800px" height="800px" viewBox="0 -38 256 256" version="1.1" xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink" preserveAspectRatio="xMidYMid">
<g>
<path d="M250.346231,28.0746923 C247.358133,17.0320558 238.732098,8.40602109 227.689461,5.41792308 C207.823743,0 127.868333,0 127.868333,0 C127.868333,0 47.9129229,0.164179487 28.0472049,5.58210256 C17.0045684,8.57020058 8.37853373,17.1962353 5.39043571,28.2388718 C-0.618533519,63.5374615 -2.94988224,117.322662 5.5546152,151.209308 C8.54271322,162.251944 17.1687479,170.877979 28.2113844,173.866077 C48.0771024,179.284 128.032513,179.284 128.032513,179.284 C128.032513,179.284 207.987923,179.284 227.853641,173.866077 C238.896277,170.877979 247.522312,162.251944 250.51041,151.209308 C256.847738,115.861464 258.801474,62.1091 250.346231,28.0746923 Z"
fill="#FF0000">
</path>
<polygon fill="#FFFFFF" points="102.420513 128.06 168.749025 89.642 102.420513 51.224">
</polygon>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.0 KiB

@ -0,0 +1,81 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api import YouTubeTranscriptApi
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class YouTubeTranscriptTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke the YouTube transcript tool
"""
try:
# Extract parameters with defaults
video_input = tool_parameters["video_id"]
language = tool_parameters.get("language")
output_format = tool_parameters.get("format", "text")
preserve_formatting = tool_parameters.get("preserve_formatting", False)
proxy = tool_parameters.get("proxy")
cookies = tool_parameters.get("cookies")
# Extract video ID from URL if needed
video_id = self._extract_video_id(video_input)
# Common kwargs for API calls
kwargs = {"proxies": {"https": proxy} if proxy else None, "cookies": cookies}
try:
if language:
transcript_list = YouTubeTranscriptApi.list_transcripts(video_id, **kwargs)
try:
transcript = transcript_list.find_transcript([language])
except:
# If requested language not found, try translating from English
transcript = transcript_list.find_transcript(["en"]).translate(language)
transcript_data = transcript.fetch()
else:
transcript_data = YouTubeTranscriptApi.get_transcript(
video_id, preserve_formatting=preserve_formatting, **kwargs
)
# Format output
formatter_class = {
"json": "JSONFormatter",
"pretty": "PrettyPrintFormatter",
"srt": "SRTFormatter",
"vtt": "WebVTTFormatter",
}.get(output_format)
if formatter_class:
from youtube_transcript_api import formatters
formatter = getattr(formatters, formatter_class)()
formatted_transcript = formatter.format_transcript(transcript_data)
else:
formatted_transcript = " ".join(entry["text"] for entry in transcript_data)
return self.create_text_message(text=formatted_transcript)
except Exception as e:
return self.create_text_message(text=f"Error getting transcript: {str(e)}")
except Exception as e:
return self.create_text_message(text=f"Error processing request: {str(e)}")
def _extract_video_id(self, video_input: str) -> str:
"""
Extract video ID from URL or return as-is if already an ID
"""
if "youtube.com" in video_input or "youtu.be" in video_input:
# Parse URL
parsed_url = urlparse(video_input)
if "youtube.com" in parsed_url.netloc:
return parse_qs(parsed_url.query)["v"][0]
else: # youtu.be
return parsed_url.path[1:]
return video_input # Assume it's already a video ID

@ -0,0 +1,101 @@
identity:
name: free_youtube_transcript
author: Tao Wang
label:
en_US: Free YouTube Transcript API
zh_Hans: 免费获取 YouTube 转录
description:
human:
en_US: Get transcript from a YouTube video for free.
zh_Hans: 免费获取 YouTube 视频的转录文案。
llm: A tool for retrieving transcript from YouTube videos.
parameters:
- name: video_id
type: string
required: true
label:
en_US: Video ID/URL
zh_Hans: 视频ID
human_description:
en_US: Used to define the video from which the transcript will be fetched. You can find the id in the video url. For example - https://www.youtube.com/watch?v=video_id.
zh_Hans: 您要哪条视频的转录文案您可以在视频链接中找到id。例如 - https://www.youtube.com/watch?v=video_id。
llm_description: Used to define the video from which the transcript will be fetched. For example - https://www.youtube.com/watch?v=video_id.
form: llm
- name: language
type: string
required: false
label:
en_US: Language Code
zh_Hans: 语言
human_description:
en_US: Language code (e.g. 'en', 'zh') for the transcript.
zh_Hans: 字幕语言代码(如'en'、'zh')。留空则自动选择。
llm_description: Used to set the language for transcripts.
form: form
- name: format
type: select
required: false
default: text
options:
- value: text
label:
en_US: Plain Text
zh_Hans: 纯文本
- value: json
label:
en_US: JSON Format
zh_Hans: JSON 格式
- value: pretty
label:
en_US: Pretty Print Format
zh_Hans: 美化格式
- value: srt
label:
en_US: SRT Format
zh_Hans: SRT 格式
- value: vtt
label:
en_US: WebVTT Format
zh_Hans: WebVTT 格式
label:
en_US: Output Format
zh_Hans: 输出格式
human_description:
en_US: Format of the transcript output
zh_Hans: 字幕输出格式
llm_description: The format to output the transcript in. Options are text (plain text), json (raw transcript data), srt (SubRip format), or vtt (WebVTT format)
form: form
- name: preserve_formatting
type: boolean
required: false
default: false
label:
en_US: Preserve Formatting
zh_Hans: 保留格式
human_description:
en_US: Keep HTML formatting elements like <i> (italics) and <b> (bold)
zh_Hans: 保留HTML格式元素如<i>(斜体)和<b>(粗体)
llm_description: Whether to preserve HTML formatting elements in the transcript text
form: form
- name: proxy
type: string
required: false
label:
en_US: HTTPS Proxy
zh_Hans: HTTPS 代理
human_description:
en_US: HTTPS proxy URL (e.g. https://user:pass@domain:port)
zh_Hans: HTTPS 代理地址(如 https://user:pass@domain:port
llm_description: HTTPS proxy to use for the request. Format should be https://user:pass@domain:port
form: form
- name: cookies
type: string
required: false
label:
en_US: Cookies File Path
zh_Hans: Cookies 文件路径
human_description:
en_US: Path to cookies.txt file for accessing age-restricted videos
zh_Hans: 用于访问年龄限制视频的 cookies.txt 文件路径
llm_description: Path to a cookies.txt file containing YouTube cookies, needed for accessing age-restricted videos
form: form

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save