Merge main

pull/9184/head
Yeuoly 2 years ago
commit 00d1c45518
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -164,7 +164,7 @@ def initialize_extensions(app):
@login_manager.request_loader @login_manager.request_loader
def load_user_from_request(request_from_flask_login): def load_user_from_request(request_from_flask_login):
"""Load user based on the request.""" """Load user based on the request."""
if request.blueprint not in ["console", "inner_api"]: if request.blueprint not in {"console", "inner_api"}:
return None return None
# Check if the user_id contains a dot, indicating the old format # Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")

@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
) )
@click.confirmation_option( @click.confirmation_option(
prompt=click.style( prompt=click.style(
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red" "Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
) )
) )
def reset_encrypt_key_pair(): def reset_encrypt_key_pair():
@ -131,7 +131,7 @@ def reset_encrypt_key_pair():
click.echo( click.echo(
click.style( click.style(
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green", fg="green",
) )
) )
@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
@click.command("vdb-migrate", help="migrate vector db.") @click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str): def vdb_migrate(scope: str):
if scope in ["knowledge", "all"]: if scope in {"knowledge", "all"}:
migrate_knowledge_vector_database() migrate_knowledge_vector_database()
if scope in ["annotation", "all"]: if scope in {"annotation", "all"}:
migrate_annotation_vector_database() migrate_annotation_vector_database()
@ -275,8 +275,7 @@ def migrate_knowledge_vector_database():
for dataset in datasets: for dataset in datasets:
total_count = total_count + 1 total_count = total_count + 1
click.echo( click.echo(
f"Processing the {total_count} dataset {dataset.id}. " f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
+ f"{create_count} created, {skipped_count} skipped."
) )
try: try:
click.echo("Create dataset vdb index: {}".format(dataset.id)) click.echo("Create dataset vdb index: {}".format(dataset.id))
@ -411,7 +410,8 @@ def migrate_knowledge_vector_database():
try: try:
click.echo( click.echo(
click.style( click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", f"Start to created vector index with {len(documents)} documents of {segments_count}"
f" segments for dataset {dataset.id}.",
fg="green", fg="green",
) )
) )
@ -593,7 +593,7 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
click.echo( click.echo(
click.style( click.style(
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password), "Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
fg="green", fg="green",
) )
) )

@ -110,6 +110,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
default=1000, default=1000,
) )
class PluginConfig(BaseSettings): class PluginConfig(BaseSettings):
""" """
Plugin configs Plugin configs
@ -124,6 +125,7 @@ class PluginConfig(BaseSettings):
default='dify-inner-api-key', default='dify-inner-api-key',
) )
class EndpointConfig(BaseSettings): class EndpointConfig(BaseSettings):
""" """
Module URL configs Module URL configs
@ -142,12 +144,12 @@ class EndpointConfig(BaseSettings):
) )
SERVICE_API_URL: str = Field( SERVICE_API_URL: str = Field(
description="Service API Url prefix." "used to display Service API Base Url to the front-end.", description="Service API Url prefix. used to display Service API Base Url to the front-end.",
default="", default="",
) )
APP_WEB_URL: str = Field( APP_WEB_URL: str = Field(
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.", description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
default="", default="",
) )
@ -285,7 +287,7 @@ class LoggingConfig(BaseSettings):
""" """
LOG_LEVEL: str = Field( LOG_LEVEL: str = Field(
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.", description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
default="INFO", default="INFO",
) )

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="0.8.0", default="0.8.2",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
site = app.site site = app.site
if not site: if not site:
desc = args["desc"] if args["desc"] else "" desc = args["desc"] or ""
copy_right = args["copyright"] if args["copyright"] else "" copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" custom_disclaimer = args["custom_disclaimer"] or ""
else: else:
desc = site.description if site.description else args["desc"] if args["desc"] else "" desc = site.description or args["desc"] or ""
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" copy_right = site.copyright or args["copyright"] or ""
privacy_policy = ( privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
)
custom_disclaimer = (
site.custom_disclaimer
if site.custom_disclaimer
else args["custom_disclaimer"]
if args["custom_disclaimer"]
else ""
)
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()

@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id): def post(self, resource_id):
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_admin_or_owner: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
current_key_count = ( current_key_count = (

@ -94,19 +94,15 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

@ -20,7 +20,7 @@ from fields.conversation_fields import (
conversation_pagination_fields, conversation_pagination_fields,
conversation_with_summary_pagination_fields, conversation_with_summary_pagination_fields,
) )
from libs.helper import datetime_string from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
) )
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
) )

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import datetime_string from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
@ -25,14 +25,17 @@ class DailyMessageStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM messages where app_id = :app_id COUNT(*) AS message_count
""" FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -45,7 +48,7 @@ class DailyMessageStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -55,10 +58,10 @@ class DailyMessageStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -79,14 +82,17 @@ class DailyConversationStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM messages where app_id = :app_id COUNT(DISTINCT messages.conversation_id) AS conversation_count
""" FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -99,7 +105,7 @@ class DailyConversationStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -109,10 +115,10 @@ class DailyConversationStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -133,14 +139,17 @@ class DailyTerminalsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM messages where app_id = :app_id COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
""" FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -153,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -163,10 +172,10 @@ class DailyTerminalsStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -187,16 +196,18 @@ class DailyTokenCostStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
sum(total_price) as total_price SUM(total_price) AS total_price
FROM messages where app_id = :app_id FROM
""" messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -209,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -219,10 +230,10 @@ class DailyTokenCostStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -245,16 +256,26 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = """SELECT
AVG(subquery.message_count) AS interactions DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count AVG(subquery.message_count) AS interactions
FROM conversations c FROM
JOIN messages m ON c.id = m.conversation_id (
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" SELECT
m.conversation_id,
COUNT(m.id) AS message_count
FROM
conversations c
JOIN
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -267,7 +288,7 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and c.created_at >= :start" sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -277,14 +298,19 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and c.created_at < :end" sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += """ sql_query += """
GROUP BY m.conversation_id) subquery GROUP BY m.conversation_id
LEFT JOIN conversations c on c.id=subquery.conversation_id ) subquery
GROUP BY date LEFT JOIN
ORDER BY date""" conversations c
ON c.id = subquery.conversation_id
GROUP BY
date
ORDER BY
date"""
response_data = [] response_data = []
@ -307,17 +333,21 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count COUNT(m.id) AS message_count,
FROM messages m COUNT(mf.id) AS feedback_count
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' FROM
WHERE m.app_id = :app_id messages m
""" LEFT JOIN
message_feedbacks mf
ON mf.message_id=m.id AND mf.rating='like'
WHERE
m.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -330,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and m.created_at >= :start" sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -340,10 +370,10 @@ class UserSatisfactionRateStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and m.created_at < :end" sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -369,16 +399,17 @@ class AverageResponseTimeStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) as latency AVG(provider_response_latency) AS latency
FROM messages FROM
WHERE app_id = :app_id messages
""" WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -391,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -401,10 +432,10 @@ class AverageResponseTimeStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -425,17 +456,20 @@ class TokensPerSecondStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = """SELECT
CASE DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0 WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
END as tokens_per_second END as tokens_per_second
FROM messages FROM
WHERE app_id = :app_id""" messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -448,7 +482,7 @@ WHERE app_id = :app_id"""
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -458,10 +492,10 @@ WHERE app_id = :app_id"""
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []

@ -502,6 +502,6 @@ api.add_resource(
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish") api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs") api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource( api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>" DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
) )
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow") api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import datetime_string from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom from models.workflow import WorkflowRunTriggeredFrom
@ -26,16 +26,18 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM workflow_runs COUNT(id) AS runs
WHERE app_id = :app_id FROM
AND triggered_from = :triggered_from workflow_runs
""" WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
@ -52,7 +54,7 @@ class WorkflowDailyRunsStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -62,10 +64,10 @@ class WorkflowDailyRunsStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -86,16 +88,18 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM workflow_runs COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
WHERE app_id = :app_id FROM
AND triggered_from = :triggered_from workflow_runs
""" WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
@ -112,7 +116,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -122,10 +126,10 @@ class WorkflowDailyTerminalsStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -146,18 +150,18 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SUM(workflow_runs.total_tokens) AS token_count
SUM(workflow_runs.total_tokens) as token_count FROM
FROM workflow_runs workflow_runs
WHERE app_id = :app_id WHERE
AND triggered_from = :triggered_from app_id = :app_id
""" AND triggered_from = :triggered_from"""
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
@ -174,7 +178,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -184,10 +188,10 @@ class WorkflowDailyTokenCostStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date" sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
@ -213,27 +217,31 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """SELECT
SELECT AVG(sub.interactions) AS interactions,
AVG(sub.interactions) as interactions, sub.date
sub.date FROM
FROM (
(SELECT SELECT
date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by, c.created_by,
COUNT(c.id) AS interactions COUNT(c.id) AS interactions
FROM workflow_runs c FROM
WHERE c.app_id = :app_id workflow_runs c
AND c.triggered_from = :triggered_from WHERE
{{start}} c.app_id = :app_id
{{end}} AND c.triggered_from = :triggered_from
GROUP BY date, c.created_by) sub {{start}}
GROUP BY sub.date {{end}}
""" GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
@ -262,7 +270,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end") sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
else: else:
sql_query = sql_query.replace("{{end}}", "") sql_query = sql_query.replace("{{end}}", "")

@ -8,7 +8,7 @@ from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, str_len, timezone from libs.helper import StrLen, email, timezone
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import AccountStatus from models.account import AccountStatus
from services.account_service import RegisterService from services.account_service import RegisterService
@ -37,7 +37,7 @@ class ActivateApi(Resource):
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument( parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json" "interface_language", type=supported_language, required=True, nullable=False, location="json"

@ -71,7 +71,7 @@ class OAuthCallback(Resource):
account = _generate_account(provider, user_info) account = _generate_account(provider, user_info)
# Check account status # Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403 return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account: if not account:
# Create account # Create account
account_name = user_info.name if user_info.name else "Dify" account_name = user_info.name or "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
) )

@ -399,7 +399,7 @@ class DatasetIndexingEstimateApi(Resource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
return { return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
"api_base_url": (
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
)
+ "/v1"
}
class DatasetRetrievalSettingApi(Resource): class DatasetRetrievalSettingApi(Resource):

@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
if document.indexing_status in ["completed", "error"]: if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule data_process_rule = document.dataset_process_rule
@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
info_list = [] info_list = []
extract_settings = [] extract_settings = []
for document in documents: for document in documents:
if document.indexing_status in ["completed", "error"]: if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
# format document files info # format document files info
@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit() db.session.commit()
elif action == "resume": elif action == "resume":
if document.indexing_status not in ["paused", "error"]: if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.") raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None document.paused_by = None

@ -18,9 +18,7 @@ class NotSetupError(BaseHTTPException):
class NotInitValidateError(BaseHTTPException): class NotInitValidateError(BaseHTTPException):
error_code = "not_init_validated" error_code = "not_init_validated"
description = ( description = "Init validation has not been completed yet. Please proceed with the init validation process first."
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
)
code = 401 code = 401

@ -81,19 +81,15 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
def get(self, installed_app): def get(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id): def delete(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)

@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
"app_owner_tenant_id": installed_app.app_owner_tenant_id, "app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned, "is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at, "last_used_at": installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"], "editable": current_user.role in {"owner", "admin"},
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
} }
for installed_app in installed_apps for installed_app in installed_apps

@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
message_id = str(message_id) message_id = str(message_id)

@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model = installed_app.app app_model = installed_app.app
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow workflow = app_model.workflow
if workflow is None: if workflow is None:
raise AppUnavailableError() raise AppUnavailableError()

@ -4,7 +4,7 @@ from flask import session
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from configs import dify_config from configs import dify_config
from libs.helper import str_len from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
raise AlreadySetupError() raise AlreadySetupError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("password", type=str_len(30), required=True, location="json") parser.add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"] input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"): if input_password != os.environ.get("INIT_PASSWORD"):

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from configs import dify_config from configs import dify_config
from libs.helper import email, get_remote_ip, str_len from libs.helper import StrLen, email, get_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.model import DifySetup from models.model import DifySetup
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
@ -40,7 +40,7 @@ class SetupApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json") parser.add_argument("name", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()

@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate") api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>") api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
api.add_resource( api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>" ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
) )
api.add_resource( api.add_resource(

@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ApiToolManageService.test_api_tool_preview( return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id, current_user.current_tenant_id,
args["provider_name"] if args["provider_name"] else "", args["provider_name"] or "",
args["tool_name"], args["tool_name"],
args["credentials"], args["credentials"],
args["parameters"], args["parameters"],

@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
raise TooManyFilesError() raise TooManyFilesError()
extension = file.filename.split(".")[-1] extension = file.filename.split(".")[-1]
if extension.lower() not in ["svg", "png"]: if extension.lower() not in {"svg", "png"}:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
try: try:

@ -64,7 +64,8 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.") abort(403, "The capacity of the vector space has reached the limit of your subscription.")
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets # The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
source = request.args.get("source") source = request.args.get("source")
if source == "datasets": if source == "datasets":
abort(403, "The number of documents has reached the limit of your subscription.") abort(403, "The number of documents has reached the limit of your subscription.")

@ -38,6 +38,7 @@ class PluginInvokeLLMApi(Resource):
return compact_generate_response(generator()) return compact_generate_response(generator())
class PluginInvokeTextEmbeddingApi(Resource): class PluginInvokeTextEmbeddingApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -113,6 +114,7 @@ class PluginInvokeNodeApi(Resource):
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode):
pass pass
class PluginInvokeAppApi(Resource): class PluginInvokeAppApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -134,6 +136,7 @@ class PluginInvokeAppApi(Resource):
PluginAppBackwardsInvocation.convert_to_event_stream(response) PluginAppBackwardsInvocation.convert_to_event_stream(response)
) )
class PluginInvokeEncryptApi(Resource): class PluginInvokeEncryptApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@ -145,6 +148,7 @@ class PluginInvokeEncryptApi(Resource):
""" """
return PluginEncrypter.invoke_encrypt(tenant_model, payload) return PluginEncrypter.invoke_encrypt(tenant_model, payload)
api.add_resource(PluginInvokeLLMApi, '/invoke/llm') api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding') api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank') api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')

@ -48,6 +48,7 @@ def get_tenant(view: Optional[Callable] = None):
else: else:
return decorator(view) return decorator(view)
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func): def decorator(view_func):
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):

@ -63,6 +63,7 @@ def enterprise_inner_api_user_auth(view):
return decorated return decorated
def plugin_inner_api_only(view): def plugin_inner_api_only(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):

@ -42,7 +42,7 @@ class AppParameterApi(Resource):
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App): def get(self, app_model: App):
"""Retrieve app parameters.""" """Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow workflow = app_model.workflow
if workflow is None: if workflow is None:
raise AppUnavailableError() raise AppUnavailableError()

@ -79,19 +79,15 @@ class TextApi(Resource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts( response = AudioService.transcript_tts(

@ -96,7 +96,7 @@ class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser): def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id): def post(self, app_model: App, end_user: EndUser, task_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

@ -18,7 +18,7 @@ class ConversationApi(Resource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser): def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id): def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id): def post(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)

@ -76,7 +76,7 @@ class MessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser): def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
def get(self, app_model: App, end_user: EndUser, message_id): def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id) message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
try: try:

@ -1,6 +1,7 @@
import logging import logging
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.service_api import api from controllers.service_api import api
@ -22,10 +23,12 @@ from core.errors.error import (
) )
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper from libs import helper
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -113,6 +116,30 @@ class WorkflowTaskStopApi(Resource):
return {"result": "success"} return {"result": "success"}
class WorkflowAppLogApi(Resource):
@validate_app_token
@marshal_with(workflow_app_log_pagination_fields)
def get(self, app_model: App):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
return workflow_app_log_pagination
api.add_resource(WorkflowRunApi, "/workflows/run") api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>") api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop") api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")

@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App, end_user): def get(self, app_model: App, end_user):
"""Retrieve app parameters.""" """Retrieve app parameters."""
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow workflow = app_model.workflow
if workflow is None: if workflow is None:
raise AppUnavailableError() raise AppUnavailableError()

@ -78,19 +78,15 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
text = args.get("text", None) text = args.get("text", None)
if ( if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = ( voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception: except Exception:
voice = None voice = None

@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource): class ChatApi(WebApiResource):
def post(self, app_model, end_user): def post(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource): class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id): def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource): class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id): def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource): class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id): def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource): class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id): def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)

@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource): class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id): def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotCompletionAppError() raise NotCompletionAppError()
message_id = str(message_id) message_id = str(message_id)

@ -80,7 +80,8 @@ def _validate_web_sso_token(decoded, system_features, app_code):
if not source or source != "sso": if not source or source != "sso":
raise WebSSOAuthRequiredError() raise WebSSOAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login # Check if SSO is not enforced for web, and if the token source is SSO,
# raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled: if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
source = decoded.get("token_source") source = decoded.get("token_source")
if source and source == "sso": if source and source == "sso":

@ -1 +1 @@
import core.moderation.base import core.moderation.base

@ -25,17 +25,19 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC): class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True _is_first_iteration = True
_ignore_observation_providers = ['wenxin'] _ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] = None _historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None _agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None _instruction: str = None
_query: str = None _query: str = None
_prompt_messages_tools: list[PromptMessage] = None _prompt_messages_tools: list[PromptMessage] = None
def run(self, message: Message, def run(
query: str, self,
inputs: dict[str, str], message: Message,
) -> Union[Generator, LLMResult]: query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
""" """
Run Cot agent application Run Cot agent application
""" """
@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
trace_manager = app_generate_entity.trace_manager trace_manager = app_generate_entity.trace_manager
# check model mode # check model mode
if 'Observation' not in app_generate_entity.model_conf.stop: if "Observation" not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append('Observation') app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config app_config = self.app_config
# init instruction # init instruction
inputs = inputs or {} inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools( self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
instruction, inputs)
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, self._prompt_messages_tools = self._init_prompt_tools() tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True function_call_state = True
llm_usage = { llm_usage = {"usage": None}
'usage': None final_answer = ""
}
final_answer = ''
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']: if not final_llm_usage_dict["usage"]:
final_llm_usage_dict['usage'] = usage final_llm_usage_dict["usage"] = usage
else: else:
llm_usage = final_llm_usage_dict['usage'] llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price llm_usage.prompt_price += usage.prompt_price
@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message_file_ids = [] message_file_ids = []
agent_thought = self.create_agent_thought( agent_thought = self.create_agent_thought(
message_id=message.id, message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
) )
if iteration_step > 1: if iteration_step > 1:
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(
agent_thought_id=agent_thought.id QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
# recalc llm max tokens # recalc llm max tokens
prompt_messages = self._organize_prompt_messages() prompt_messages = self._organize_prompt_messages()
@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
raise ValueError("failed to invoke llm") raise ValueError("failed to invoke llm")
usage_dict = {} usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output( react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
chunks, usage_dict)
scratchpad = AgentScratchpadUnit( scratchpad = AgentScratchpadUnit(
agent_response='', agent_response="",
thought='', thought="",
action_str='', action_str="",
observation='', observation="",
action=None, action=None,
) )
# publish agent thought if it's first iteration # publish agent thought if it's first iteration
if iteration_step == 1: if iteration_step == 1:
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(
agent_thought_id=agent_thought.id QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
for chunk in react_chunks: for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action): if isinstance(chunk, AgentScratchpadUnit.Action):
@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
yield LLMResultChunk( yield LLMResultChunk(
model=self.model_config.model, model=self.model_config.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
system_fingerprint='', system_fingerprint="",
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
) )
scratchpad.thought = scratchpad.thought.strip( scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
) or 'I am thinking about how to help you'
self._agent_scratchpad.append(scratchpad) self._agent_scratchpad.append(scratchpad)
# get llm usage # get llm usage
if 'usage' in usage_dict: if "usage" in usage_dict:
increase_usage(llm_usage, usage_dict['usage']) increase_usage(llm_usage, usage_dict["usage"])
else: else:
usage_dict['usage'] = LLMUsage.empty_usage() usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '', tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={ tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
scratchpad.action.action_name: scratchpad.action.action_input
} if scratchpad.action else {},
tool_invoke_meta={}, tool_invoke_meta={},
thought=scratchpad.thought, thought=scratchpad.thought,
observation='', observation="",
answer=scratchpad.agent_response, answer=scratchpad.agent_response,
messages_ids=[], messages_ids=[],
llm_usage=usage_dict['usage'] llm_usage=usage_dict["usage"],
) )
if not scratchpad.is_final(): if not scratchpad.is_final():
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(
agent_thought_id=agent_thought.id QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
if not scratchpad.action: if not scratchpad.action:
# failed to extract action, return final answer directly # failed to extract action, return final answer directly
final_answer = '' final_answer = ""
else: else:
if scratchpad.action.action_name.lower() == "final answer": if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly # action is final answer, return final answer directly
try: try:
if isinstance(scratchpad.action.action_input, dict): if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps( final_answer = json.dumps(scratchpad.action.action_input)
scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str): elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input final_answer = scratchpad.action.action_input
else: else:
final_answer = f'{scratchpad.action.action_input}' final_answer = f"{scratchpad.action.action_input}"
except json.JSONDecodeError: except json.JSONDecodeError:
final_answer = f'{scratchpad.action.action_input}' final_answer = f"{scratchpad.action.action_input}"
else: else:
function_call_state = True function_call_state = True
# action is tool call, invoke tool # action is tool call, invoke tool
@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=scratchpad.action.action_name, tool_name=scratchpad.action.action_name,
tool_input={ tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought, thought=scratchpad.thought,
observation={ observation={scratchpad.action.action_name: tool_invoke_response},
scratchpad.action.action_name: tool_invoke_response}, tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response, answer=scratchpad.agent_response,
messages_ids=message_file_ids, messages_ids=message_file_ids,
llm_usage=usage_dict['usage'] llm_usage=usage_dict["usage"],
) )
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(
agent_thought_id=agent_thought.id QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
# update prompt tool message # update prompt tool message
for prompt_tool in self._prompt_messages_tools: for prompt_tool in self._prompt_messages_tools:
@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model, model=model_instance.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=0, index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage']
), ),
system_fingerprint='' system_fingerprint="",
) )
# save agent thought # save agent thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name='', tool_name="",
tool_input={}, tool_input={},
tool_invoke_meta={}, tool_invoke_meta={},
thought=final_answer, thought=final_answer,
observation={}, observation={},
answer=final_answer, answer=final_answer,
messages_ids=[] messages_ids=[],
) )
self.update_db_variables(self.variables_pool, self.db_variables_pool) self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event # publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( self.queue_manager.publish(
model=model_instance.model, QueueMessageEndEvent(
prompt_messages=prompt_messages, llm_result=LLMResult(
message=AssistantPromptMessage( model=model_instance.model,
content=final_answer prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
), ),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), PublishFrom.APPLICATION_MANAGER,
system_fingerprint='' )
)), PublishFrom.APPLICATION_MANAGER)
def _handle_invoke_action(
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, self,
tool_instances: dict[str, Tool], action: AgentScratchpadUnit.Action,
message_file_ids: list[str], tool_instances: dict[str, Tool],
trace_manager: Optional[TraceQueueManager] = None message_file_ids: list[str],
) -> tuple[str, ToolInvokeMeta]: trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
""" """
handle invoke action handle invoke action
:param action: action :param action: action
@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files # publish files
for message_file_id, save_as in message_files: for message_file_id, save_as in message_files:
if save_as: if save_as:
self.variables_pool.set_file( self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file # publish message file
self.queue_manager.publish(QueueMessageFileEvent( self.queue_manager.publish(
message_file_id=message_file_id QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
# add message file ids # add message file ids
message_file_ids.append(message_file_id) message_file_ids.append(message_file_id)
@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
""" """
convert dict to action convert dict to action
""" """
return AgentScratchpadUnit.Action( return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
action_name=action['action'],
action_input=action['action_input']
)
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
""" """
@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
""" """
for key, value in inputs.items(): for key, value in inputs.items():
try: try:
instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception as e: except Exception as e:
continue continue
@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
@abstractmethod @abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]: def _organize_prompt_messages(self) -> list[PromptMessage]:
""" """
organize prompt messages organize prompt messages
""" """
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
""" """
format assistant message format assistant message
""" """
message = '' message = ""
for scratchpad in agent_scratchpad: for scratchpad in agent_scratchpad:
if scratchpad.is_final(): if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}" message += f"Final Answer: {scratchpad.agent_response}"
@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message return message
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]: def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
""" """
organize historic prompt messages organize historic prompt messages
""" """
result: list[PromptMessage] = [] result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = [] scratchpads: list[AgentScratchpadUnit] = []
@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not current_scratchpad: if not current_scratchpad:
current_scratchpad = AgentScratchpadUnit( current_scratchpad = AgentScratchpadUnit(
agent_response=message.content, agent_response=message.content,
thought=message.content or 'I am thinking about how to help you', thought=message.content or "I am thinking about how to help you",
action_str='', action_str="",
action=None, action=None,
observation=None, observation=None,
) )
@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try: try:
current_scratchpad.action = AgentScratchpadUnit.Action( current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name, action_name=message.tool_calls[0].function.name,
action_input=json.loads( action_input=json.loads(message.tool_calls[0].function.arguments),
message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
) )
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except: except:
pass pass
elif isinstance(message, ToolPromptMessage): elif isinstance(message, ToolPromptMessage):
@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad.observation = message.content current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage): elif isinstance(message, UserPromptMessage):
if scratchpads: if scratchpads:
result.append(AssistantPromptMessage( result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
content=self._format_assistant_message(scratchpads)
))
scratchpads = [] scratchpads = []
current_scratchpad = None current_scratchpad = None
result.append(message) result.append(message)
if scratchpads: if scratchpads:
result.append(AssistantPromptMessage( result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
content=self._format_assistant_message(scratchpads)
))
historic_prompts = AgentHistoryPromptTransform( historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config, model_config=self.model_config,
prompt_messages=current_session_messages or [], prompt_messages=current_session_messages or [],
history_messages=result, history_messages=result,
memory=self.memory memory=self.memory,
).get_prompt() ).get_prompt()
return historic_prompts return historic_prompts

@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt \ system_prompt = (
.replace("{{instruction}}", self._instruction) \ first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return SystemPromptMessage(content=system_prompt) return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
""" """
Organize user query Organize user query
""" """
@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
def _organize_prompt_messages(self) -> list[PromptMessage]: def _organize_prompt_messages(self) -> list[PromptMessage]:
""" """
Organize Organize
""" """
# organize system prompt # organize system prompt
system_message = self._organize_system_prompt() system_message = self._organize_system_prompt()
@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad: if not agent_scratchpad:
assistant_messages = [] assistant_messages = []
else: else:
assistant_message = AssistantPromptMessage(content='') assistant_message = AssistantPromptMessage(content="")
for unit in agent_scratchpad: for unit in agent_scratchpad:
if unit.is_final(): if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}" assistant_message.content += f"Final Answer: {unit.agent_response}"
@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
if assistant_messages: if assistant_messages:
# organize historic prompt messages # organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([ historic_messages = self._organize_historic_prompt_messages(
system_message, [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
*query_messages, )
*assistant_messages,
UserPromptMessage(content='continue')
])
messages = [ messages = [
system_message, system_message,
*historic_messages, *historic_messages,
*query_messages, *query_messages,
*assistant_messages, *assistant_messages,
UserPromptMessage(content='continue') UserPromptMessage(content="continue"),
] ]
else: else:
# organize historic prompt messages # organize historic prompt messages

@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \ system_prompt = (
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return system_prompt return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str: def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
# organize current assistant messages # organize current assistant messages
agent_scratchpad = self._agent_scratchpad agent_scratchpad = self._agent_scratchpad
assistant_prompt = '' assistant_prompt = ""
for unit in agent_scratchpad: for unit in agent_scratchpad:
if unit.is_final(): if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}" assistant_prompt += f"Final Answer: {unit.agent_response}"
@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
query_prompt = f"Question: {self._query}" query_prompt = f"Question: {self._query}"
# join all messages # join all messages
prompt = system_prompt \ prompt = (
.replace("{{historic_messages}}", historic_prompt) \ system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt) \ .replace("{{agent_scratchpad}}", assistant_prompt)
.replace("{{query}}", query_prompt) .replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)] return [UserPromptMessage(content=prompt)]

@ -20,6 +20,7 @@ class AgentPromptEntity(BaseModel):
""" """
Agent Prompt Entity. Agent Prompt Entity.
""" """
first_prompt: str first_prompt: str
next_iteration: str next_iteration: str
@ -33,6 +34,7 @@ class AgentScratchpadUnit(BaseModel):
""" """
Action Entity. Action Entity.
""" """
action_name: str action_name: str
action_input: Union[dict, str] action_input: Union[dict, str]
@ -41,8 +43,8 @@ class AgentScratchpadUnit(BaseModel):
Convert to dictionary. Convert to dictionary.
""" """
return { return {
'action': self.action_name, "action": self.action_name,
'action_input': self.action_input, "action_input": self.action_input,
} }
agent_response: Optional[str] = None agent_response: Optional[str] = None
@ -56,10 +58,10 @@ class AgentScratchpadUnit(BaseModel):
Check if the scratchpad unit is final. Check if the scratchpad unit is final.
""" """
return self.action is None or ( return self.action is None or (
'final' in self.action.action_name.lower() and "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
'answer' in self.action.action_name.lower()
) )
class AgentEntity(BaseModel): class AgentEntity(BaseModel):
""" """
Agent Entity. Agent Entity.
@ -69,8 +71,9 @@ class AgentEntity(BaseModel):
""" """
Agent Strategy. Agent Strategy.
""" """
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling' CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"
provider: str provider: str
model: str model: str

@ -24,11 +24,9 @@ from models.model import Message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, class FunctionCallAgentRunner(BaseAgentRunner):
message: Message, query: str, **kwargs: Any def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
) -> Generator[LLMResultChunk, None, None]:
""" """
Run FunctionCall agent application Run FunctionCall agent application
""" """
@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call # continue to run until there is not any tool call
function_call_state = True function_call_state = True
llm_usage = { llm_usage = {"usage": None}
'usage': None final_answer = ""
}
final_answer = ''
# get tracing instance # get tracing instance
trace_manager = app_generate_entity.trace_manager trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']: if not final_llm_usage_dict["usage"]:
final_llm_usage_dict['usage'] = usage final_llm_usage_dict["usage"] = usage
else: else:
llm_usage = final_llm_usage_dict['usage'] llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price llm_usage.prompt_price += usage.prompt_price
@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
message_file_ids = [] message_file_ids = []
agent_thought = self.create_agent_thought( agent_thought = self.create_agent_thought(
message_id=message.id, message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
) )
# recalc llm max tokens # recalc llm max tokens
@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls: list[tuple[str, str, dict[str, Any]]] = [] tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response # save full response
response = '' response = ""
# save tool call names and inputs # save tool call names and inputs
tool_call_names = '' tool_call_names = ""
tool_call_inputs = '' tool_call_inputs = ""
current_llm_usage = None current_llm_usage = None
@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
is_first_chunk = True is_first_chunk = True
for chunk in chunks: for chunk in chunks:
if is_first_chunk: if is_first_chunk:
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(
agent_thought_id=agent_thought.id QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
is_first_chunk = False is_first_chunk = False
# check if there is any tool call # check if there is any tool call
if self.check_tool_calls(chunk): if self.check_tool_calls(chunk):
function_call_state = True function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk)) tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try: try:
tool_call_inputs = json.dumps({ tool_call_inputs = json.dumps(
tool_call[1]: tool_call[2] for tool_call in tool_calls {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
}, ensure_ascii=False) )
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error # ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({ tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content: if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list): if isinstance(chunk.delta.message.content, list):
@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.check_blocking_tool_calls(result): if self.check_blocking_tool_calls(result):
function_call_state = True function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result)) tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try: try:
tool_call_inputs = json.dumps({ tool_call_inputs = json.dumps(
tool_call[1]: tool_call[2] for tool_call in tool_calls {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
}, ensure_ascii=False) )
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error # ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({ tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if result.usage: if result.usage:
increase_usage(llm_usage, result.usage) increase_usage(llm_usage, result.usage)
@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
response += result.message.content response += result.message.content
if not result.message.content: if not result.message.content:
result.message.content = '' result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk( yield LLMResultChunk(
model=model_instance.model, model=model_instance.model,
prompt_messages=result.prompt_messages, prompt_messages=result.prompt_messages,
@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
index=0, index=0,
message=result.message, message=result.message,
usage=result.usage, usage=result.usage,
) ),
) )
assistant_message = AssistantPromptMessage( assistant_message = AssistantPromptMessage(content="", tool_calls=[])
content='',
tool_calls=[]
)
if tool_calls: if tool_calls:
assistant_message.tool_calls=[ assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall( AssistantPromptMessage.ToolCall(
id=tool_call[0], id=tool_call[0],
type='function', type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction( function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1], name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
arguments=json.dumps(tool_call[2], ensure_ascii=False) ),
) )
) for tool_call in tool_calls for tool_call in tool_calls
] ]
else: else:
assistant_message.content = response assistant_message.content = response
self._current_thoughts.append(assistant_message) self._current_thoughts.append(assistant_message)
# save thought # save thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=tool_call_names, tool_name=tool_call_names,
tool_input=tool_call_inputs, tool_input=tool_call_inputs,
thought=response, thought=response,
@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
observation=None, observation=None,
answer=response, answer=response,
messages_ids=[], messages_ids=[],
llm_usage=current_llm_usage llm_usage=current_llm_usage,
) )
self.queue_manager.publish(QueueAgentThoughtEvent( self.queue_manager.publish(
agent_thought_id=agent_thought.id QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
final_answer += response + '\n' final_answer += response + "\n"
# call tools # call tools
tool_responses = [] tool_responses = []
@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"tool_call_name": tool_call_name, "tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}", "tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict() "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
} }
else: else:
# invoke tool # invoke tool
@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file # publish message file
self.queue_manager.publish(QueueMessageFileEvent( self.queue_manager.publish(
message_file_id=message_file_id QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER) )
# add message file ids # add message file ids
message_file_ids.append(message_file_id) message_file_ids.append(message_file_id)
tool_response = { tool_response = {
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"tool_call_name": tool_call_name, "tool_call_name": tool_call_name,
"tool_response": tool_invoke_response, "tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict() "meta": tool_invoke_meta.to_dict(),
} }
tool_responses.append(tool_response) tool_responses.append(tool_response)
if tool_response['tool_response'] is not None: if tool_response["tool_response"] is not None:
self._current_thoughts.append( self._current_thoughts.append(
ToolPromptMessage( ToolPromptMessage(
content=tool_response['tool_response'], content=tool_response["tool_response"],
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
name=tool_call_name, name=tool_call_name,
) )
) )
if len(tool_responses) > 0: if len(tool_responses) > 0:
# save agent thought # save agent thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=None, tool_name=None,
tool_input=None, tool_input=None,
thought=None, thought=None,
tool_invoke_meta={ tool_invoke_meta={
tool_response['tool_call_name']: tool_response['meta'] tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
for tool_response in tool_responses
}, },
observation={ observation={
tool_response['tool_call_name']: tool_response['tool_response'] tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses for tool_response in tool_responses
}, },
answer=None, answer=None,
messages_ids=message_file_ids messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
) )
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool # update prompt tool
for prompt_tool in prompt_messages_tools: for prompt_tool in prompt_messages_tools:
@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool) self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event # publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( self.queue_manager.publish(
model=model_instance.model, QueueMessageEndEvent(
prompt_messages=prompt_messages, llm_result=LLMResult(
message=AssistantPromptMessage( model=model_instance.model,
content=final_answer prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
), ),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), PublishFrom.APPLICATION_MANAGER,
system_fingerprint='' )
)), PublishFrom.APPLICATION_MANAGER)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
""" """
@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if llm_result_chunk.delta.message.tool_calls: if llm_result_chunk.delta.message.tool_calls:
return True return True
return False return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
""" """
Check if there is any blocking tool call in llm result Check if there is any blocking tool call in llm result
@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True return True
return False return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
""" """
Extract tool calls from llm result chunk Extract tool calls from llm result chunk
@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = [] tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls: for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {} args = {}
if prompt_message.function.arguments != '': if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments) args = json.loads(prompt_message.function.arguments)
tool_calls.append(( tool_calls.append(
prompt_message.id, (
prompt_message.function.name, prompt_message.id,
args, prompt_message.function.name,
)) args,
)
)
return tool_calls return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
""" """
Extract blocking tool calls from llm result Extract blocking tool calls from llm result
@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = [] tool_calls = []
for prompt_message in llm_result.message.tool_calls: for prompt_message in llm_result.message.tool_calls:
args = {} args = {}
if prompt_message.function.arguments != '': if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments) args = json.loads(prompt_message.function.arguments)
tool_calls.append(( tool_calls.append(
prompt_message.id, (
prompt_message.function.name, prompt_message.id,
args, prompt_message.function.name,
)) args,
)
)
return tool_calls return tool_calls
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
""" """
Initialize system message Initialize system message
""" """
@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return [ return [
SystemPromptMessage(content=prompt_template), SystemPromptMessage(content=prompt_template),
] ]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
""" """
Organize user query Organize user query
""" """
@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages.append(UserPromptMessage(content=query)) prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
""" """
As for now, gpt supports both fc and vision at the first iteration. As for now, gpt supports both fc and vision at the first iteration.
@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list): if isinstance(prompt_message.content, list):
prompt_message.content = '\n'.join([ prompt_message.content = "\n".join(
content.data if content.type == PromptMessageContentType.TEXT else [
'[image]' if content.type == PromptMessageContentType.IMAGE else content.data
'[file]' if content.type == PromptMessageContentType.TEXT
for content in prompt_message.content else "[image]"
]) if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages return prompt_messages
def _organize_prompt_messages(self): def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or '' prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, []) query_prompt_messages = self._organize_user_query(self.query, [])
@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_config=self.model_config, model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts], prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages, history_messages=self.history_prompt_messages,
memory=self.memory memory=self.memory,
).get_prompt() ).get_prompt()
prompt_messages = [ prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
if len(self._current_thoughts) != 0: if len(self._current_thoughts) != 0:
# clear messages after the first iteration # clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser: class CotAgentOutputParser:
@classmethod @classmethod
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \ def handle_react_stream_output(
Generator[Union[str, AgentScratchpadUnit.Action], None, None]: cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str): def parse_action(json_str):
try: try:
action = json.loads(json_str) action = json.loads(json_str)
@ -22,7 +23,7 @@ class CotAgentOutputParser:
action = action[0] action = action[0]
for key, value in action.items(): for key, value in action.items():
if 'input' in key.lower(): if "input" in key.lower():
action_input = value action_input = value
else: else:
action_name = value action_name = value
@ -33,37 +34,37 @@ class CotAgentOutputParser:
action_input=action_input, action_input=action_input,
) )
else: else:
return json_str or '' return json_str or ""
except: except:
return json_str or '' return json_str or ""
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks: if not code_blocks:
return return
for block in code_blocks: for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
yield parse_action(json_text) yield parse_action(json_text)
code_block_cache = '' code_block_cache = ""
code_block_delimiter_count = 0 code_block_delimiter_count = 0
in_code_block = False in_code_block = False
json_cache = '' json_cache = ""
json_quote_count = 0 json_quote_count = 0
in_json = False in_json = False
got_json = False got_json = False
action_cache = '' action_cache = ""
action_str = 'action:' action_str = "action:"
action_idx = 0 action_idx = 0
thought_cache = '' thought_cache = ""
thought_str = 'thought:' thought_str = "thought:"
thought_idx = 0 thought_idx = 0
for response in llm_response: for response in llm_response:
if response.delta.usage: if response.delta.usage:
usage_dict['usage'] = response.delta.usage usage_dict["usage"] = response.delta.usage
response = response.delta.message.content response = response.delta.message.content
if not isinstance(response, str): if not isinstance(response, str):
continue continue
@ -72,24 +73,24 @@ class CotAgentOutputParser:
index = 0 index = 0
while index < len(response): while index < len(response):
steps = 1 steps = 1
delta = response[index:index+steps] delta = response[index : index + steps]
last_character = response[index-1] if index > 0 else '' last_character = response[index - 1] if index > 0 else ""
if delta == '`': if delta == "`":
code_block_cache += delta code_block_cache += delta
code_block_delimiter_count += 1 code_block_delimiter_count += 1
else: else:
if not in_code_block: if not in_code_block:
if code_block_delimiter_count > 0: if code_block_delimiter_count > 0:
yield code_block_cache yield code_block_cache
code_block_cache = '' code_block_cache = ""
else: else:
code_block_cache += delta code_block_cache += delta
code_block_delimiter_count = 0 code_block_delimiter_count = 0
if not in_code_block and not in_json: if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0: if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ['\n', ' ', '']: if last_character not in {"\n", " ", ""}:
index += steps index += steps
yield delta yield delta
continue continue
@ -97,7 +98,7 @@ class CotAgentOutputParser:
action_cache += delta action_cache += delta
action_idx += 1 action_idx += 1
if action_idx == len(action_str): if action_idx == len(action_str):
action_cache = '' action_cache = ""
action_idx = 0 action_idx = 0
index += steps index += steps
continue continue
@ -105,18 +106,18 @@ class CotAgentOutputParser:
action_cache += delta action_cache += delta
action_idx += 1 action_idx += 1
if action_idx == len(action_str): if action_idx == len(action_str):
action_cache = '' action_cache = ""
action_idx = 0 action_idx = 0
index += steps index += steps
continue continue
else: else:
if action_cache: if action_cache:
yield action_cache yield action_cache
action_cache = '' action_cache = ""
action_idx = 0 action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0: if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ['\n', ' ', '']: if last_character not in {"\n", " ", ""}:
index += steps index += steps
yield delta yield delta
continue continue
@ -124,7 +125,7 @@ class CotAgentOutputParser:
thought_cache += delta thought_cache += delta
thought_idx += 1 thought_idx += 1
if thought_idx == len(thought_str): if thought_idx == len(thought_str):
thought_cache = '' thought_cache = ""
thought_idx = 0 thought_idx = 0
index += steps index += steps
continue continue
@ -132,31 +133,31 @@ class CotAgentOutputParser:
thought_cache += delta thought_cache += delta
thought_idx += 1 thought_idx += 1
if thought_idx == len(thought_str): if thought_idx == len(thought_str):
thought_cache = '' thought_cache = ""
thought_idx = 0 thought_idx = 0
index += steps index += steps
continue continue
else: else:
if thought_cache: if thought_cache:
yield thought_cache yield thought_cache
thought_cache = '' thought_cache = ""
thought_idx = 0 thought_idx = 0
if code_block_delimiter_count == 3: if code_block_delimiter_count == 3:
if in_code_block: if in_code_block:
yield from extra_json_from_code_block(code_block_cache) yield from extra_json_from_code_block(code_block_cache)
code_block_cache = '' code_block_cache = ""
in_code_block = not in_code_block in_code_block = not in_code_block
code_block_delimiter_count = 0 code_block_delimiter_count = 0
if not in_code_block: if not in_code_block:
# handle single json # handle single json
if delta == '{': if delta == "{":
json_quote_count += 1 json_quote_count += 1
in_json = True in_json = True
json_cache += delta json_cache += delta
elif delta == '}': elif delta == "}":
json_cache += delta json_cache += delta
if json_quote_count > 0: if json_quote_count > 0:
json_quote_count -= 1 json_quote_count -= 1
@ -172,12 +173,12 @@ class CotAgentOutputParser:
if got_json: if got_json:
got_json = False got_json = False
yield parse_action(json_cache) yield parse_action(json_cache)
json_cache = '' json_cache = ""
json_quote_count = 0 json_quote_count = 0
in_json = False in_json = False
if not in_code_block and not in_json: if not in_code_block and not in_json:
yield delta.replace('`', '') yield delta.replace("`", "")
index += steps index += steps
@ -186,4 +187,3 @@ class CotAgentOutputParser:
if json_cache: if json_cache:
yield parse_action(json_cache) yield parse_action(json_cache)

@ -41,7 +41,8 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
{{historic_messages}} {{historic_messages}}
Question: {{query}} Question: {{query}}
{{agent_scratchpad}} {{agent_scratchpad}}
Thought:""" Thought:""" # noqa: E501
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:""" Thought:"""
@ -86,19 +87,20 @@ Action:
``` ```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
""" """ # noqa: E501
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
REACT_PROMPT_TEMPLATES = { REACT_PROMPT_TEMPLATES = {
'english': { "english": {
'chat': { "chat": {
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
},
"completion": {
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
}, },
'completion': {
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
}
} }
} }

@ -26,34 +26,24 @@ class BaseAppConfigManager:
config_dict = dict(config_dict.items()) config_dict = dict(config_dict.items())
additional_features = AppAdditionalFeatures() additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
config=config_dict
)
additional_features.file_upload = FileUploadConfigManager.convert( additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict, config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
) )
additional_features.opening_statement, additional_features.suggested_questions = \ additional_features.opening_statement, additional_features.suggested_questions = (
OpeningStatementConfigManager.convert( OpeningStatementConfigManager.convert(config=config_dict)
config=config_dict )
)
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
config=config_dict config=config_dict
) )
additional_features.more_like_this = MoreLikeThisConfigManager.convert( additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
config=config_dict
)
additional_features.speech_to_text = SpeechToTextConfigManager.convert( additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
config=config_dict
)
additional_features.text_to_speech = TextToSpeechConfigManager.convert( additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
config=config_dict
)
return additional_features return additional_features

@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager: class SensitiveWordAvoidanceConfigManager:
@classmethod @classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance') sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict: if not sensitive_word_avoidance_dict:
return None return None
if sensitive_word_avoidance_dict.get('enabled'): if sensitive_word_avoidance_dict.get("enabled"):
return SensitiveWordAvoidanceEntity( return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'), type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get('config'), config=sensitive_word_avoidance_dict.get("config"),
) )
else: else:
return None return None
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ def validate_and_set_defaults(
-> tuple[dict, list[str]]: cls, tenant_id, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"): if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = { config["sensitive_word_avoidance"] = {"enabled": False}
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict): if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type") raise ValueError("sensitive_word_avoidance must be of dict type")
@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
typ = config["sensitive_word_avoidance"]["type"] typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
ModerationFactory.validate_config( ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
name=typ,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config
)
return config, ["sensitive_word_avoidance"] return config, ["sensitive_word_avoidance"]

@ -12,67 +12,70 @@ class AgentConfigManager:
:param config: model config args :param config: model config args
""" """
if 'agent_mode' in config and config['agent_mode'] \ if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
and 'enabled' in config['agent_mode']: agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
agent_dict = config.get('agent_mode', {}) if agent_strategy == "function_call":
agent_strategy = agent_dict.get('strategy', 'cot')
if agent_strategy == 'function_call':
strategy = AgentEntity.Strategy.FUNCTION_CALLING strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == 'cot' or agent_strategy == 'react': elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else: else:
# old configs, try to detect default strategy # old configs, try to detect default strategy
if config['model']['provider'] == 'openai': if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING strategy = AgentEntity.Strategy.FUNCTION_CALLING
else: else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = [] agent_tools = []
for tool in agent_dict.get('tools', []): for tool in agent_dict.get("tools", []):
keys = tool.keys() keys = tool.keys()
if len(keys) >= 4: if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]: if "enabled" not in tool or not tool["enabled"]:
continue continue
agent_tool_properties = { agent_tool_properties = {
'provider_type': tool['provider_type'], "provider_type": tool["provider_type"],
'provider_id': tool['provider_id'], "provider_id": tool["provider_id"],
'tool_name': tool['tool_name'], "tool_name": tool["tool_name"],
'tool_parameters': tool.get('tool_parameters', {}) "tool_parameters": tool.get("tool_parameters", {}),
} }
agent_tools.append(AgentToolEntity(**agent_tool_properties)) agent_tools.append(AgentToolEntity(**agent_tool_properties))
if 'strategy' in config['agent_mode'] and \ if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
config['agent_mode']['strategy'] not in ['react_router', 'router']: "react_router",
agent_prompt = agent_dict.get('prompt', None) or {} "router",
}:
agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode # check model mode
model_mode = config.get('model', {}).get('mode', 'completion') model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == 'completion': if model_mode == "completion":
agent_prompt_entity = AgentPromptEntity( agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt', first_prompt=agent_prompt.get(
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
next_iteration=agent_prompt.get('next_iteration', ),
REACT_PROMPT_TEMPLATES['english']['completion'][ next_iteration=agent_prompt.get(
'agent_scratchpad']), "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
),
) )
else: else:
agent_prompt_entity = AgentPromptEntity( agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt', first_prompt=agent_prompt.get(
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
next_iteration=agent_prompt.get('next_iteration', ),
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
),
) )
return AgentEntity( return AgentEntity(
provider=config['model']['provider'], provider=config["model"]["provider"],
model=config['model']['name'], model=config["model"]["name"],
strategy=strategy, strategy=strategy,
prompt=agent_prompt_entity, prompt=agent_prompt_entity,
tools=agent_tools, tools=agent_tools,
max_iteration=agent_dict.get('max_iteration', 5) max_iteration=agent_dict.get("max_iteration", 5),
) )
return None return None

@ -15,39 +15,38 @@ class DatasetConfigManager:
:param config: model config args :param config: model config args
""" """
dataset_ids = [] dataset_ids = []
if 'datasets' in config.get('dataset_configs', {}): if "datasets" in config.get("dataset_configs", {}):
datasets = config.get('dataset_configs', {}).get('datasets', { datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
'strategy': 'router',
'datasets': []
})
for dataset in datasets.get('datasets', []): for dataset in datasets.get("datasets", []):
keys = list(dataset.keys()) keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != 'dataset': if len(keys) == 0 or keys[0] != "dataset":
continue continue
dataset = dataset['dataset'] dataset = dataset["dataset"]
if 'enabled' not in dataset or not dataset['enabled']: if "enabled" not in dataset or not dataset["enabled"]:
continue continue
dataset_id = dataset.get('id', None) dataset_id = dataset.get("id", None)
if dataset_id: if dataset_id:
dataset_ids.append(dataset_id) dataset_ids.append(dataset_id)
if 'agent_mode' in config and config['agent_mode'] \ if (
and 'enabled' in config['agent_mode'] \ "agent_mode" in config
and config['agent_mode']['enabled']: and config["agent_mode"]
and "enabled" in config["agent_mode"]
and config["agent_mode"]["enabled"]
):
agent_dict = config.get("agent_mode", {})
agent_dict = config.get('agent_mode', {}) for tool in agent_dict.get("tools", []):
for tool in agent_dict.get('tools', []):
keys = tool.keys() keys = tool.keys()
if len(keys) == 1: if len(keys) == 1:
# old standard # old standard
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key != 'dataset': if key != "dataset":
continue continue
tool_item = tool[key] tool_item = tool[key]
@ -55,30 +54,28 @@ class DatasetConfigManager:
if "enabled" not in tool_item or not tool_item["enabled"]: if "enabled" not in tool_item or not tool_item["enabled"]:
continue continue
dataset_id = tool_item['id'] dataset_id = tool_item["id"]
dataset_ids.append(dataset_id) dataset_ids.append(dataset_id)
if len(dataset_ids) == 0: if len(dataset_ids) == 0:
return None return None
# dataset configs # dataset configs
if 'dataset_configs' in config and config.get('dataset_configs'): if "dataset_configs" in config and config.get("dataset_configs"):
dataset_configs = config.get('dataset_configs') dataset_configs = config.get("dataset_configs")
else: else:
dataset_configs = { dataset_configs = {"retrieval_model": "multiple"}
'retrieval_model': 'multiple' query_variable = config.get("dataset_query_variable")
}
query_variable = config.get('dataset_query_variable')
if dataset_configs['retrieval_model'] == 'single': if dataset_configs["retrieval_model"] == "single":
return DatasetEntity( return DatasetEntity(
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity( retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable, query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model'] dataset_configs["retrieval_model"]
) ),
) ),
) )
else: else:
return DatasetEntity( return DatasetEntity(
@ -86,15 +83,15 @@ class DatasetConfigManager:
retrieve_config=DatasetRetrieveConfigEntity( retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable, query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model'] dataset_configs["retrieval_model"]
), ),
top_k=dataset_configs.get('top_k', 4), top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get('score_threshold'), score_threshold=dataset_configs.get("score_threshold"),
reranking_model=dataset_configs.get('reranking_model'), reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get('weights'), weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get('reranking_enabled', True), reranking_enabled=dataset_configs.get("reranking_enabled", True),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
) ),
) )
@classmethod @classmethod
@ -111,13 +108,10 @@ class DatasetConfigManager:
# dataset_configs # dataset_configs
if not config.get("dataset_configs"): if not config.get("dataset_configs"):
config["dataset_configs"] = {'retrieval_model': 'single'} config["dataset_configs"] = {"retrieval_model": "single"}
if not config["dataset_configs"].get("datasets"): if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = { config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
"strategy": "router",
"datasets": []
}
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
@ -125,8 +119,9 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = (config.get("dataset_configs") need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
and config["dataset_configs"].get("datasets", {}).get("datasets")) "datasets", {}
).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION: if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion # Only check when mode is completion
@ -148,10 +143,7 @@ class DatasetConfigManager:
""" """
# Extract dataset config for legacy compatibility # Extract dataset config for legacy compatibility
if not config.get("agent_mode"): if not config.get("agent_mode"):
config["agent_mode"] = { config["agent_mode"] = {"enabled": False, "tools": []}
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict): if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type") raise ValueError("agent_mode must be of object type")
@ -175,7 +167,7 @@ class DatasetConfigManager:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False has_datasets = False
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
for tool in config["agent_mode"]["tools"]: for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key == "dataset": if key == "dataset":
@ -188,7 +180,7 @@ class DatasetConfigManager:
if not isinstance(tool_item["enabled"], bool): if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type") raise ValueError("enabled in agent_mode.tools must be of boolean type")
if 'id' not in tool_item: if "id" not in tool_item:
raise ValueError("id is required in dataset") raise ValueError("id is required in dataset")
try: try:

@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter: class ModelConfigConverter:
@classmethod @classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
skip_check: bool = False) \
-> ModelConfigWithCredentialsEntity:
""" """
Convert app model config dict to entity. Convert app model config dict to entity.
:param app_config: app config :param app_config: app config
@ -25,9 +23,7 @@ class ModelConfigConverter:
provider_manager = ProviderManager() provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle( provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id, tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
provider=model_config.provider,
model_type=ModelType.LLM
) )
provider_name = provider_model_bundle.configuration.provider.provider provider_name = provider_model_bundle.configuration.provider.provider
@ -38,8 +34,7 @@ class ModelConfigConverter:
# check model credentials # check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials( model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model_type=ModelType.LLM, model=model_config.model
model=model_config.model
) )
if model_credentials is None: if model_credentials is None:
@ -51,8 +46,7 @@ class ModelConfigConverter:
if not skip_check: if not skip_check:
# check model # check model
provider_model = provider_model_bundle.configuration.get_provider_model( provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model=model_config.model, model_type=ModelType.LLM
model_type=ModelType.LLM
) )
if provider_model is None: if provider_model is None:
@ -69,24 +63,18 @@ class ModelConfigConverter:
# model config # model config
completion_params = model_config.parameters completion_params = model_config.parameters
stop = [] stop = []
if 'stop' in completion_params: if "stop" in completion_params:
stop = completion_params['stop'] stop = completion_params["stop"]
del completion_params['stop'] del completion_params["stop"]
# get model mode # get model mode
model_mode = model_config.mode model_mode = model_config.mode
if not model_mode: if not model_mode:
mode_enum = model_type_instance.get_model_mode( mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
model=model_config.model,
credentials=model_credentials
)
model_mode = mode_enum.value model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema( model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
model_config.model,
model_credentials
)
if not skip_check and not model_schema: if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")

@ -13,23 +13,23 @@ class ModelConfigManager:
:param config: model config args :param config: model config args
""" """
# model config # model config
model_config = config.get('model') model_config = config.get("model")
if not model_config: if not model_config:
raise ValueError("model is required") raise ValueError("model is required")
completion_params = model_config.get('completion_params') completion_params = model_config.get("completion_params")
stop = [] stop = []
if 'stop' in completion_params: if "stop" in completion_params:
stop = completion_params['stop'] stop = completion_params["stop"]
del completion_params['stop'] del completion_params["stop"]
# get model mode # get model mode
model_mode = model_config.get('mode') model_mode = model_config.get("mode")
return ModelConfigEntity( return ModelConfigEntity(
provider=config['model']['provider'], provider=config["model"]["provider"],
model=config['model']['name'], model=config["model"]["name"],
mode=model_mode, mode=model_mode,
parameters=completion_params, parameters=completion_params,
stop=stop, stop=stop,
@ -43,7 +43,7 @@ class ModelConfigManager:
:param tenant_id: tenant id :param tenant_id: tenant id
:param config: app model config args :param config: app model config args
""" """
if 'model' not in config: if "model" not in config:
raise ValueError("model is required") raise ValueError("model is required")
if not isinstance(config["model"], dict): if not isinstance(config["model"], dict):
@ -52,17 +52,16 @@ class ModelConfigManager:
# model.provider # model.provider
provider_entities = model_provider_factory.get_providers() provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities] model_provider_names = [provider.provider for provider in provider_entities]
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name # model.name
if 'name' not in config["model"]: if "name" not in config["model"]:
raise ValueError("model.name is required") raise ValueError("model.name is required")
provider_manager = ProviderManager() provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models( models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], provider=config["model"]["provider"], model_type=ModelType.LLM
model_type=ModelType.LLM
) )
if not models: if not models:
@ -80,12 +79,12 @@ class ModelConfigManager:
# model.mode # model.mode
if model_mode: if model_mode:
config['model']["mode"] = model_mode config["model"]["mode"] = model_mode
else: else:
config['model']["mode"] = "completion" config["model"]["mode"] = "completion"
# model.completion_params # model.completion_params
if 'completion_params' not in config["model"]: if "completion_params" not in config["model"]:
raise ValueError("model.completion_params is required") raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = cls.validate_model_completion_params( config["model"]["completion_params"] = cls.validate_model_completion_params(
@ -101,7 +100,7 @@ class ModelConfigManager:
raise ValueError("model.completion_params must be of object type") raise ValueError("model.completion_params must be of object type")
# stop # stop
if 'stop' not in cp: if "stop" not in cp:
cp["stop"] = [] cp["stop"] = []
elif not isinstance(cp["stop"], list): elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type") raise ValueError("stop in model.completion_params must be of list type")

@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
if not config.get("prompt_type"): if not config.get("prompt_type"):
raise ValueError("prompt_type is required") raise ValueError("prompt_type is required")
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type']) prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = config.get("pre_prompt", "") simple_prompt_template = config.get("pre_prompt", "")
return PromptTemplateEntity( return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
else: else:
advanced_chat_prompt_template = None advanced_chat_prompt_template = None
chat_prompt_config = config.get("chat_prompt_config", {}) chat_prompt_config = config.get("chat_prompt_config", {})
if chat_prompt_config: if chat_prompt_config:
chat_prompt_messages = [] chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []): for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({ chat_prompt_messages.append(
"text": message["text"], {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
"role": PromptMessageRole.value_of(message["role"]) )
})
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
messages=chat_prompt_messages
)
advanced_completion_prompt_template = None advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {}) completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config: if completion_prompt_config:
completion_prompt_template_params = { completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'], "prompt": completion_prompt_config["prompt"]["text"],
} }
if 'conversation_histories_role' in completion_prompt_config: if "conversation_histories_role" in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = { completion_prompt_template_params["role_prefix"] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'], "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
} }
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
return PromptTemplateEntity( return PromptTemplateEntity(
prompt_type=prompt_type, prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template, advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template advanced_completion_prompt_template=advanced_completion_prompt_template,
) )
@classmethod @classmethod
@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config['prompt_type'] not in prompt_type_vals: if config["prompt_type"] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}") raise ValueError(f"prompt_type must be in {prompt_type_vals}")
# chat_prompt_config # chat_prompt_config
@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict): if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type") raise ValueError("completion_prompt_config must be of object type")
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config['chat_prompt_config'] and not config['completion_prompt_config']: if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError("chat_prompt_config or completion_prompt_config is required " raise ValueError(
"when prompt_type is advanced") "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
)
model_mode_vals = [mode.value for mode in ModelMode] model_mode_vals = [mode.value for mode in ModelMode]
if config['model']["mode"] not in model_mode_vals: if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
if not user_prefix: if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
if not assistant_prefix: if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config['model']["mode"] == ModelMode.CHAT.value: if config["model"]["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt'] prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10: if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10") raise ValueError("prompt messages must be less than 10")

@ -16,51 +16,49 @@ class BasicVariablesConfigManager:
variable_entities = [] variable_entities = []
# old external_data_tools # old external_data_tools
external_data_tools = config.get('external_data_tools', []) external_data_tools = config.get("external_data_tools", [])
for external_data_tool in external_data_tools: for external_data_tool in external_data_tools:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']: if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
continue continue
external_data_variables.append( external_data_variables.append(
ExternalDataVariableEntity( ExternalDataVariableEntity(
variable=external_data_tool['variable'], variable=external_data_tool["variable"],
type=external_data_tool['type'], type=external_data_tool["type"],
config=external_data_tool['config'] config=external_data_tool["config"],
) )
) )
# variables and external_data_tools # variables and external_data_tools
for variables in config.get('user_input_form', []): for variables in config.get("user_input_form", []):
variable_type = list(variables.keys())[0] variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type] variable = variables[variable_type]
if 'config' not in variable: if "config" not in variable:
continue continue
external_data_variables.append( external_data_variables.append(
ExternalDataVariableEntity( ExternalDataVariableEntity(
variable=variable['variable'], variable=variable["variable"], type=variable["type"], config=variable["config"]
type=variable['type'],
config=variable['config']
) )
) )
elif variable_type in [ elif variable_type in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH, VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER, VariableEntityType.NUMBER,
VariableEntityType.SELECT, VariableEntityType.SELECT,
]: }:
variable = variables[variable_type] variable = variables[variable_type]
variable_entities.append( variable_entities.append(
VariableEntity( VariableEntity(
type=variable_type, type=variable_type,
variable=variable.get('variable'), variable=variable.get("variable"),
description=variable.get('description'), description=variable.get("description"),
label=variable.get('label'), label=variable.get("label"),
required=variable.get('required', False), required=variable.get("required", False),
max_length=variable.get('max_length'), max_length=variable.get("max_length"),
options=variable.get('options'), options=variable.get("options"),
default=variable.get('default'), default=variable.get("default"),
) )
) )
@ -99,17 +97,17 @@ class BasicVariablesConfigManager:
variables = [] variables = []
for item in config["user_input_form"]: for item in config["user_input_form"]:
key = list(item.keys())[0] key = list(item.keys())[0]
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key] form_item = item[key]
if 'label' not in form_item: if "label" not in form_item:
raise ValueError("label is required in user_input_form") raise ValueError("label is required in user_input_form")
if not isinstance(form_item["label"], str): if not isinstance(form_item["label"], str):
raise ValueError("label in user_input_form must be of string type") raise ValueError("label in user_input_form must be of string type")
if 'variable' not in form_item: if "variable" not in form_item:
raise ValueError("variable is required in user_input_form") raise ValueError("variable is required in user_input_form")
if not isinstance(form_item["variable"], str): if not isinstance(form_item["variable"], str):
@ -117,26 +115,24 @@ class BasicVariablesConfigManager:
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None: if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, " raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
"and cannot start with a number")
variables.append(form_item["variable"]) variables.append(form_item["variable"])
if 'required' not in form_item or not form_item["required"]: if "required" not in form_item or not form_item["required"]:
form_item["required"] = False form_item["required"] = False
if not isinstance(form_item["required"], bool): if not isinstance(form_item["required"], bool):
raise ValueError("required in user_input_form must be of boolean type") raise ValueError("required in user_input_form must be of boolean type")
if key == "select": if key == "select":
if 'options' not in form_item or not form_item["options"]: if "options" not in form_item or not form_item["options"]:
form_item["options"] = [] form_item["options"] = []
if not isinstance(form_item["options"], list): if not isinstance(form_item["options"], list):
raise ValueError("options in user_input_form must be a list of strings") raise ValueError("options in user_input_form must be a list of strings")
if "default" in form_item and form_item['default'] \ if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
and form_item["default"] not in form_item["options"]:
raise ValueError("default value in user_input_form must be in the options list") raise ValueError("default value in user_input_form must be in the options list")
return config, ["user_input_form"] return config, ["user_input_form"]
@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
typ = tool["type"] typ = tool["type"]
config = tool["config"] config = tool["config"]
ExternalDataToolFactory.validate_config( ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
name=typ,
tenant_id=tenant_id,
config=config
)
return config, ["external_data_tools"] return config, ["external_data_tools"]

@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
""" """
Model Config Entity. Model Config Entity.
""" """
provider: str provider: str
model: str model: str
mode: Optional[str] = None mode: Optional[str] = None
@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
""" """
Advanced Chat Message Entity. Advanced Chat Message Entity.
""" """
text: str text: str
role: PromptMessageRole role: PromptMessageRole
@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
""" """
Advanced Chat Prompt Template Entity. Advanced Chat Prompt Template Entity.
""" """
messages: list[AdvancedChatMessageEntity] messages: list[AdvancedChatMessageEntity]
@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
""" """
Role Prefix Entity. Role Prefix Entity.
""" """
user: str user: str
assistant: str assistant: str
@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
Prompt Type. Prompt Type.
'simple', 'advanced' 'simple', 'advanced'
""" """
SIMPLE = 'simple'
ADVANCED = 'advanced' SIMPLE = "simple"
ADVANCED = "advanced"
@classmethod @classmethod
def value_of(cls, value: str) -> 'PromptType': def value_of(cls, value: str) -> "PromptType":
""" """
Get value of given mode. Get value of given mode.
@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
for mode in cls: for mode in cls:
if mode.value == value: if mode.value == value:
return mode return mode
raise ValueError(f'invalid prompt type value {value}') raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType prompt_type: PromptType
simple_prompt_template: Optional[str] = None simple_prompt_template: Optional[str] = None
@ -87,7 +92,7 @@ class VariableEntityType(str, Enum):
SELECT = "select" SELECT = "select"
PARAGRAPH = "paragraph" PARAGRAPH = "paragraph"
NUMBER = "number" NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool" EXTERNAL_DATA_TOOL = "external_data_tool"
class VariableEntity(BaseModel): class VariableEntity(BaseModel):
@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
""" """
External Data Variable Entity. External Data Variable Entity.
""" """
variable: str variable: str
type: str type: str
config: dict[str, Any] = {} config: dict[str, Any] = {}
@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Strategy. Dataset Retrieve Strategy.
'single' or 'multiple' 'single' or 'multiple'
""" """
SINGLE = 'single'
MULTIPLE = 'multiple' SINGLE = "single"
MULTIPLE = "multiple"
@classmethod @classmethod
def value_of(cls, value: str) -> 'RetrieveStrategy': def value_of(cls, value: str) -> "RetrieveStrategy":
""" """
Get value of given mode. Get value of given mode.
@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
for mode in cls: for mode in cls:
if mode.value == value: if mode.value == value:
return mode return mode
raise ValueError(f'invalid retrieve strategy value {value}') raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None top_k: Optional[int] = None
score_threshold: Optional[float] = .0 score_threshold: Optional[float] = 0.0
rerank_mode: Optional[str] = 'reranking_model' rerank_mode: Optional[str] = "reranking_model"
reranking_model: Optional[dict] = None reranking_model: Optional[dict] = None
weights: Optional[dict] = None weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True reranking_enabled: Optional[bool] = True
class DatasetEntity(BaseModel): class DatasetEntity(BaseModel):
""" """
Dataset Config Entity. Dataset Config Entity.
""" """
dataset_ids: list[str] dataset_ids: list[str]
retrieve_config: DatasetRetrieveConfigEntity retrieve_config: DatasetRetrieveConfigEntity
@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
""" """
Sensitive Word Avoidance Entity. Sensitive Word Avoidance Entity.
""" """
type: str type: str
config: dict[str, Any] = {} config: dict[str, Any] = {}
@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
""" """
Sensitive Word Avoidance Entity. Sensitive Word Avoidance Entity.
""" """
enabled: bool enabled: bool
voice: Optional[str] = None voice: Optional[str] = None
language: Optional[str] = None language: Optional[str] = None
@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
""" """
Tracing Config Entity. Tracing Config Entity.
""" """
enabled: bool enabled: bool
tracing_provider: str tracing_provider: str
class AppAdditionalFeatures(BaseModel): class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileExtraConfig] = None file_upload: Optional[FileExtraConfig] = None
opening_statement: Optional[str] = None opening_statement: Optional[str] = None
@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
text_to_speech: Optional[TextToSpeechEntity] = None text_to_speech: Optional[TextToSpeechEntity] = None
trace_config: Optional[TracingConfigEntity] = None trace_config: Optional[TracingConfigEntity] = None
class AppConfig(BaseModel): class AppConfig(BaseModel):
""" """
Application Config Entity. Application Config Entity.
""" """
tenant_id: str tenant_id: str
app_id: str app_id: str
app_mode: AppMode app_mode: AppMode
@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
""" """
App Model Config From. App Model Config From.
""" """
ARGS = 'args'
APP_LATEST_CONFIG = 'app-latest-config' ARGS = "args"
CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
class EasyUIBasedAppConfig(AppConfig): class EasyUIBasedAppConfig(AppConfig):
""" """
Easy UI Based App Config Entity. Easy UI Based App Config Entity.
""" """
app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str app_model_config_id: str
app_model_config_dict: dict app_model_config_dict: dict
@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
""" """
Workflow UI Based App Config Entity. Workflow UI Based App Config Entity.
""" """
workflow_id: str workflow_id: str

@ -13,21 +13,19 @@ class FileUploadConfigManager:
:param config: model config args :param config: model config args
:param is_vision: if True, the feature is vision feature :param is_vision: if True, the feature is vision feature
""" """
file_upload_dict = config.get('file_upload') file_upload_dict = config.get("file_upload")
if file_upload_dict: if file_upload_dict:
if file_upload_dict.get('image'): if file_upload_dict.get("image"):
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
image_config = { image_config = {
'number_limits': file_upload_dict['image']['number_limits'], "number_limits": file_upload_dict["image"]["number_limits"],
'transfer_methods': file_upload_dict['image']['transfer_methods'] "transfer_methods": file_upload_dict["image"]["transfer_methods"],
} }
if is_vision: if is_vision:
image_config['detail'] = file_upload_dict['image']['detail'] image_config["detail"] = file_upload_dict["image"]["detail"]
return FileExtraConfig( return FileExtraConfig(image_config=image_config)
image_config=image_config
)
return None return None
@ -49,21 +47,21 @@ class FileUploadConfigManager:
if not config["file_upload"].get("image"): if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False} config["file_upload"]["image"] = {"enabled": False}
if config['file_upload']['image']['enabled']: if config["file_upload"]["image"]["enabled"]:
number_limits = config['file_upload']['image']['number_limits'] number_limits = config["file_upload"]["image"]["number_limits"]
if number_limits < 1 or number_limits > 6: if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]") raise ValueError("number_limits must be in [1, 6]")
if is_vision: if is_vision:
detail = config['file_upload']['image']['detail'] detail = config["file_upload"]["image"]["detail"]
if detail not in ['high', 'low']: if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']") raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config['file_upload']['image']['transfer_methods'] transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list): if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type") raise ValueError("transfer_methods must be of list type")
for method in transfer_methods: for method in transfer_methods:
if method not in ['remote_url', 'local_file']: if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"] return config, ["file_upload"]

@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
:param config: model config args :param config: model config args
""" """
more_like_this = False more_like_this = False
more_like_this_dict = config.get('more_like_this') more_like_this_dict = config.get("more_like_this")
if more_like_this_dict: if more_like_this_dict:
if more_like_this_dict.get('enabled'): if more_like_this_dict.get("enabled"):
more_like_this = True more_like_this = True
return more_like_this return more_like_this
@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
:param config: app model config args :param config: app model config args
""" """
if not config.get("more_like_this"): if not config.get("more_like_this"):
config["more_like_this"] = { config["more_like_this"] = {"enabled": False}
"enabled": False
}
if not isinstance(config["more_like_this"], dict): if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type") raise ValueError("more_like_this must be of dict type")

@ -1,5 +1,3 @@
class OpeningStatementConfigManager: class OpeningStatementConfigManager:
@classmethod @classmethod
def convert(cls, config: dict) -> tuple[str, list]: def convert(cls, config: dict) -> tuple[str, list]:
@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
:param config: model config args :param config: model config args
""" """
# opening statement # opening statement
opening_statement = config.get('opening_statement') opening_statement = config.get("opening_statement")
# suggested questions # suggested questions
suggested_questions_list = config.get('suggested_questions') suggested_questions_list = config.get("suggested_questions")
return opening_statement, suggested_questions_list return opening_statement, suggested_questions_list

@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
@classmethod @classmethod
def convert(cls, config: dict) -> bool: def convert(cls, config: dict) -> bool:
show_retrieve_source = False show_retrieve_source = False
retriever_resource_dict = config.get('retriever_resource') retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict: if retriever_resource_dict:
if retriever_resource_dict.get('enabled'): if retriever_resource_dict.get("enabled"):
show_retrieve_source = True show_retrieve_source = True
return show_retrieve_source return show_retrieve_source
@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
:param config: app model config args :param config: app model config args
""" """
if not config.get("retriever_resource"): if not config.get("retriever_resource"):
config["retriever_resource"] = { config["retriever_resource"] = {"enabled": False}
"enabled": False
}
if not isinstance(config["retriever_resource"], dict): if not isinstance(config["retriever_resource"], dict):
raise ValueError("retriever_resource must be of dict type") raise ValueError("retriever_resource must be of dict type")

@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
:param config: model config args :param config: model config args
""" """
speech_to_text = False speech_to_text = False
speech_to_text_dict = config.get('speech_to_text') speech_to_text_dict = config.get("speech_to_text")
if speech_to_text_dict: if speech_to_text_dict:
if speech_to_text_dict.get('enabled'): if speech_to_text_dict.get("enabled"):
speech_to_text = True speech_to_text = True
return speech_to_text return speech_to_text
@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
:param config: app model config args :param config: app model config args
""" """
if not config.get("speech_to_text"): if not config.get("speech_to_text"):
config["speech_to_text"] = { config["speech_to_text"] = {"enabled": False}
"enabled": False
}
if not isinstance(config["speech_to_text"], dict): if not isinstance(config["speech_to_text"], dict):
raise ValueError("speech_to_text must be of dict type") raise ValueError("speech_to_text must be of dict type")

@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: model config args :param config: model config args
""" """
suggested_questions_after_answer = False suggested_questions_after_answer = False
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer') suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
if suggested_questions_after_answer_dict: if suggested_questions_after_answer_dict:
if suggested_questions_after_answer_dict.get('enabled'): if suggested_questions_after_answer_dict.get("enabled"):
suggested_questions_after_answer = True suggested_questions_after_answer = True
return suggested_questions_after_answer return suggested_questions_after_answer
@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: app model config args :param config: app model config args
""" """
if not config.get("suggested_questions_after_answer"): if not config.get("suggested_questions_after_answer"):
config["suggested_questions_after_answer"] = { config["suggested_questions_after_answer"] = {"enabled": False}
"enabled": False
}
if not isinstance(config["suggested_questions_after_answer"], dict): if not isinstance(config["suggested_questions_after_answer"], dict):
raise ValueError("suggested_questions_after_answer must be of dict type") raise ValueError("suggested_questions_after_answer must be of dict type")
if "enabled" not in config["suggested_questions_after_answer"] or not \ if (
config["suggested_questions_after_answer"]["enabled"]: "enabled" not in config["suggested_questions_after_answer"]
or not config["suggested_questions_after_answer"]["enabled"]
):
config["suggested_questions_after_answer"]["enabled"] = False config["suggested_questions_after_answer"]["enabled"] = False
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):

@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
:param config: model config args :param config: model config args
""" """
text_to_speech = None text_to_speech = None
text_to_speech_dict = config.get('text_to_speech') text_to_speech_dict = config.get("text_to_speech")
if text_to_speech_dict: if text_to_speech_dict:
if text_to_speech_dict.get('enabled'): if text_to_speech_dict.get("enabled"):
text_to_speech = TextToSpeechEntity( text_to_speech = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'), enabled=text_to_speech_dict.get("enabled"),
voice=text_to_speech_dict.get('voice'), voice=text_to_speech_dict.get("voice"),
language=text_to_speech_dict.get('language'), language=text_to_speech_dict.get("language"),
) )
return text_to_speech return text_to_speech
@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
:param config: app model config args :param config: app model config args
""" """
if not config.get("text_to_speech"): if not config.get("text_to_speech"):
config["text_to_speech"] = { config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
"enabled": False,
"voice": "",
"language": ""
}
if not isinstance(config["text_to_speech"], dict): if not isinstance(config["text_to_speech"], dict):
raise ValueError("text_to_speech must be of dict type") raise ValueError("text_to_speech must be of dict type")

@ -1,4 +1,3 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.app_config.entities import WorkflowUIBasedAppConfig
@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
""" """
Advanced Chatbot App Config Entity. Advanced Chatbot App Config Entity.
""" """
pass pass
class AdvancedChatAppConfigManager(BaseAppConfigManager): class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def get_app_config(cls, app_model: App, def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
app_id=app_model.id, app_id=app_model.id,
app_mode=app_mode, app_mode=app_mode,
workflow_id=workflow.id, workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
config=features_dict variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
), additional_features=cls.convert_features(features_dict, app_mode),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
) )
return app_config return app_config
@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config, config=config, is_vision=False
is_vision=False
) )
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer # suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config) config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# speech_to_text # speech_to_text
@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
config=config,
only_structure_validate=only_structure_validate
) )
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
filtered_config = {key: config.get(key) for key in related_config_keys} filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config return filtered_config

@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator): class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
@ -54,7 +56,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
@ -63,14 +66,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> Union[dict[str, Any], Generator[dict | str, None, None]]: ... ) -> Union[dict[str, Any], Generator[dict | str, None, None]]: ...
def generate( def generate(
self, self,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
) -> dict[str, Any] | Generator[str | dict, None, None]: ) -> dict[str, Any] | Generator[str | dict, None, None]:
""" """
Generate App response. Generate App response.
@ -81,44 +84,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
if not args.get('query'): if not args.get("query"):
raise ValueError('query is required') raise ValueError("query is required")
query = args['query'] query = args["query"]
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError('query must be a string') raise ValueError("query must be a string")
query = query.replace('\x00', '') query = query.replace("\x00", "")
inputs = args['inputs'] inputs = args["inputs"]
extras = { extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
# get conversation # get conversation
conversation = None conversation = None
conversation_id = args.get('conversation_id') conversation_id = args.get("conversation_id")
if conversation_id: if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) conversation = self._get_conversation_by_user(
app_model=app_model, conversation_id=conversation_id, user=user
)
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config( app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_model=app_model,
workflow=workflow
)
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id user_id = user.id if isinstance(user, Account) else user.session_id
@ -140,7 +136,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
trace_manager=trace_manager trace_manager=trace_manager,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -150,16 +146,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from, invoke_from=invoke_from,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
conversation=conversation, conversation=conversation,
stream=stream stream=stream,
) )
def single_iteration_generate(self, app_model: App, def single_iteration_generate(
workflow: Workflow, self, app_model: App, workflow: Workflow, node_id: str, user: Account | EndUser, args: dict, stream: bool = True
node_id: str, ) -> dict[str, Any] | Generator[str, Any, None]:
user: Account | EndUser,
args: dict,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
""" """
Generate App response. Generate App response.
@ -171,16 +163,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream :param stream: is stream
""" """
if not node_id: if not node_id:
raise ValueError('node_id is required') raise ValueError("node_id is required")
if args.get('inputs') is None: if args.get("inputs") is None:
raise ValueError('inputs is required') raise ValueError("inputs is required")
# convert to app config # convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config( app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_model=app_model,
workflow=workflow
)
# init application generate entity # init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity( application_generate_entity = AdvancedChatAppGenerateEntity(
@ -188,18 +177,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config, app_config=app_config,
conversation_id=None, conversation_id=None,
inputs={}, inputs={},
query='', query="",
files=[], files=[],
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
extras={ extras={"auto_generate_conversation_name": False},
"auto_generate_conversation_name": False
},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, node_id=node_id, inputs=args["inputs"]
inputs=args['inputs'] ),
)
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -209,17 +195,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
conversation=None, conversation=None,
stream=stream stream=stream,
) )
def _generate(self, *, def _generate(
workflow: Workflow, self,
user: Union[Account, EndUser], *,
invoke_from: InvokeFrom, workflow: Workflow,
application_generate_entity: AdvancedChatAppGenerateEntity, user: Union[Account, EndUser],
conversation: Optional[Conversation] = None, invoke_from: InvokeFrom,
stream: bool = True) \ application_generate_entity: AdvancedChatAppGenerateEntity,
-> dict[str, Any] | Generator[str, Any, None]: conversation: Optional[Conversation] = None,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
""" """
Generate App response. Generate App response.
@ -235,10 +223,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
is_first_conversation = True is_first_conversation = True
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation: if is_first_conversation:
# update conversation features # update conversation features
@ -253,18 +238,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), # type: ignore target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(), # type: ignore
'conversation_id': conversation.id, "application_generate_entity": application_generate_entity,
'message_id': message.id, "queue_manager": queue_manager,
'context': contextvars.copy_context(), "conversation_id": conversation.id,
}) "message_id": message.id,
"context": contextvars.copy_context(),
},
)
worker_thread.start() worker_thread.start()
@ -278,18 +266,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user, user=user,
stream=stream, stream=stream,
) )
return AdvancedChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
def _generate_worker(self, flask_app: Flask, return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager, def _generate_worker(
conversation_id: str, self,
message_id: str, flask_app: Flask,
context: contextvars.Context) -> None: application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context,
) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -312,22 +300,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
conversation=conversation, conversation=conversation,
message=message message=message,
) )
runner.run() runner.run()
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG", "false").lower() == 'true': if os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
@ -373,7 +360,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(e)
raise e raise e

@ -21,14 +21,11 @@ class AudioTrunk:
self.status = status self.status = status
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace(): if not text_content or text_content.isspace():
return return
return model_instance.invoke_tts( return model_instance.invoke_tts(
content_text=text_content.strip(), content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
user="responding_tts",
tenant_id=tenant_id,
voice=voice
) )
@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
except Exception as e: except Exception as e:
logging.getLogger(__name__).warning(e) logging.getLogger(__name__).warning(e)
break break
audio_queue.put(AudioTrunk("finish", b'')) audio_queue.put(AudioTrunk("finish", b""))
class AppGeneratorTTSPublisher: class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str): def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.msg_text = '' self.msg_text = ""
self._audio_queue = queue.Queue() self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue() self._msg_queue = queue.Queue()
self.match = re.compile(r'[。.!?]') self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager() self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance( self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, tenant_id=self.tenant_id, model_type=ModelType.TTS
model_type=ModelType.TTS
) )
self.voices = self.model_instance.get_tts_voices() self.voices = self.model_instance.get_tts_voices()
values = [voice.get('value') for voice in self.voices] values = [voice.get("value") for voice in self.voices]
self.voice = voice self.voice = voice
if not voice or voice not in values: if not voice or voice not in values:
self.voice = self.voices[0].get('value') self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2 self.MAX_SENTENCE = 2
self._last_audio_event = None self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start() self._runtime_thread = threading.Thread(target=self._runtime).start()
@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
message = self._msg_queue.get() message = self._msg_queue.get()
if message is None: if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0: if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(_invoiceTTS, self.msg_text, futures_result = self.executor.submit(
self.model_instance, self.tenant_id, self.voice) _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result) future_queue.put(futures_result)
break break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@ -94,28 +90,27 @@ class AppGeneratorTTSPublisher:
elif isinstance(message.event, QueueTextChunkEvent): elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent): elif isinstance(message.event, QueueNodeSucceededEvent):
self.msg_text += message.event.outputs.get('output', '') self.msg_text += message.event.outputs.get("output", "")
self.last_message = message self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text) sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1 self.MAX_SENTENCE += 1
text_content = ''.join(sentence_arr) text_content = "".join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content, futures_result = self.executor.submit(
self.model_instance, _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
self.tenant_id, )
self.voice)
future_queue.put(futures_result) future_queue.put(futures_result)
if text_tmp: if text_tmp:
self.msg_text = text_tmp self.msg_text = text_tmp
else: else:
self.msg_text = '' self.msg_text = ""
except Exception as e: except Exception as e:
self.logger.warning(e) self.logger.warning(e)
break break
future_queue.put(None) future_queue.put(None)
def checkAndGetAudio(self) -> AudioTrunk | None: def check_and_get_audio(self) -> AudioTrunk | None:
try: try:
if self._last_audio_event and self._last_audio_event.status == "finish": if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor: if self.executor:

@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent, QueueStopEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
) )
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
""" """
def __init__( def __init__(
self, self,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message message: Message,
) -> None: ) -> None:
""" """
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -66,14 +66,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError('App not found') raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow: if not workflow:
raise ValueError('Workflow not initialized') raise ValueError("Workflow not initialized")
user_id = None user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user: if end_user:
user_id = end_user.session_id user_id = end_user.session_id
@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run:
@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs user_inputs=self.application_generate_entity.single_iteration_run.inputs,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs
@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation # moderation
if self.handle_input_moderation( if self.handle_input_moderation(
app_record=app_record, app_record=app_record,
app_generate_entity=self.application_generate_entity, app_generate_entity=self.application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
message_id=self.message.id message_id=self.message.id,
): ):
return return
# annotation reply # annotation reply
if self.handle_annotation_reply( if self.handle_annotation_reply(
app_record=app_record, app_record=app_record,
message=self.message, message=self.message,
query=query, query=query,
app_generate_entity=self.application_generate_entity app_generate_entity=self.application_generate_entity,
): ):
return return
# Init conversation variables # Init conversation variables
stmt = select(ConversationVariable).where( stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
) )
with Session(db.engine) as session: with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all() conversation_variables = session.scalars(stmt).all()
@ -174,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER else UserFrom.END_USER
), ),
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._handle_event(workflow_entry, event) self._handle_event(workflow_entry, event)
def handle_input_moderation( def handle_input_moderation(
self, self,
app_record: App, app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity, app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any], inputs: Mapping[str, Any],
query: str, query: str,
message_id: str message_id: str,
) -> bool: ) -> bool:
""" """
Handle input moderation Handle input moderation
@ -216,19 +217,15 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query, query=query,
message_id=message_id, message_id=message_id,
) )
except ModerationException as e: except ModerationError as e:
self._complete_with_stream_output( self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
text=str(e),
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True return True
return False return False
def handle_annotation_reply(self, app_record: App, def handle_annotation_reply(
message: Message, self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
query: str, ) -> bool:
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
""" """
Handle annotation reply Handle annotation reply
:param app_record: app record :param app_record: app record
@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
) )
if annotation_reply: if annotation_reply:
self._publish_event( self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
)
self._complete_with_stream_output( self._complete_with_stream_output(
text=annotation_reply.content, text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
) )
return True return True
return False return False
def _complete_with_stream_output(self, def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
""" """
Direct output Direct output
:param text: text :param text: text
:return: :return:
""" """
self._publish_event( self._publish_event(QueueTextChunkEvent(text=text))
QueueTextChunkEvent(
text=text
)
)
self._publish_event( self._publish_event(QueueStopEvent(stopped_by=stopped_by))
QueueStopEvent(stopped_by=stopped_by)
)

@ -27,15 +27,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = { response = {
'event': 'message', "event": "message",
'task_id': blocking_response.task_id, "task_id": blocking_response.task_id,
'id': blocking_response.data.id, "id": blocking_response.data.id,
'message_id': blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id, "conversation_id": blocking_response.data.conversation_id,
'mode': blocking_response.data.mode, "mode": blocking_response.data.mode,
'answer': blocking_response.data.answer, "answer": blocking_response.data.answer,
'metadata': blocking_response.data.metadata, "metadata": blocking_response.data.metadata,
'created_at': blocking_response.data.created_at "created_at": blocking_response.data.created_at,
} }
return response return response
@ -49,13 +49,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {}) metadata = response.get("metadata", {})
response['metadata'] = cls._get_simple_metadata(metadata) response["metadata"] = cls._get_simple_metadata(metadata)
return response return response
@classmethod @classmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[dict | str, Any, None]: def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -66,14 +68,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -84,7 +86,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield response_chunk yield response_chunk
@classmethod @classmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[dict | str, Any, None]: def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -95,20 +99,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)

@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
""" """
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity _application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow _workflow: Workflow
@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
def __init__( def __init__(
self, self,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow, workflow: Workflow,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
""" """
Initialize AdvancedChatAppGenerateTaskPipeline. Initialize AdvancedChatAppGenerateTaskPipeline.
@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._conversation, self._application_generate_entity.query
self._application_generate_entity.query
) )
generator = self._wrapper_process_stream_response( generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream: if self._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(stream_response, MessageEndStreamResponse): elif isinstance(stream_response, MessageEndStreamResponse):
extras = {} extras = {}
if stream_response.metadata: if stream_response.metadata:
extras['metadata'] = stream_response.metadata extras["metadata"] = stream_response.metadata
return ChatbotAppBlockingResponse( return ChatbotAppBlockingResponse(
task_id=stream_response.task_id, task_id=stream_response.task_id,
@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message_id=self._message.id, message_id=self._message.id,
answer=self._task_state.answer, answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()), created_at=int(self._message.created_at.timestamp()),
**extras **extras,
) ),
) )
else: else:
continue continue
raise Exception('Queue listening stopped unexpectedly.') raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
""" """
To stream response. To stream response.
:return: :return:
@ -176,32 +176,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation.id, conversation_id=self._conversation.id,
message_id=self._message.id, message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()), created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response stream_response=stream_response,
) )
def _listenAudioMsg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher, task_id: str):
if not publisher: if not publisher:
return None return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio() audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ def _wrapper_process_stream_response(
Generator[StreamResponse, None, None]: self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None tts_publisher = None
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ if (
'text_to_speech'].get('autoPlay') == 'enabled': features_dict.get("text_to_speech")
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response: if audio_response:
yield audio_response yield audio_response
else: else:
@ -214,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
try: try:
if not tts_publisher: if not tts_publisher:
break break
audio_trunk = tts_publisher.checkAndGetAudio() audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None: if audio_trunk is None:
# release cpu # release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
break break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response( def _process_stream_response(
self, self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None, tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
""" """
Process stream response. Process stream response.
@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
db.session.close() db.session.close()
yield self._workflow_start_to_stream_response( yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
elif isinstance(event, QueueNodeStartedEvent): elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start( workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
workflow_run=workflow_run,
event=event
)
response = self._workflow_node_start_to_stream_response( response = self._workflow_node_start_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response( response = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response( response = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
yield response yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response( yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response( yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationStartEvent): elif isinstance(event, QueueIterationStartEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response( yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationNextEvent): elif isinstance(event, QueueIterationNextEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response( yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationCompletedEvent): elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response( yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueWorkflowSucceededEvent): elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.') raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success( workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run, workflow_run=workflow_run,
@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
) )
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
self._queue_manager.publish( self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent): elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.') raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed( workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run, workflow_run=workflow_run,
@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
) )
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message)) yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break break
elif isinstance(event, QueueStopEvent): elif isinstance(event, QueueStopEvent):
@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
) )
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
# Save message # Save message
@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message() self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ self._message.message_metadata = (
if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit() db.session.commit()
db.session.refresh(self._message) db.session.refresh(self._message)
@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message() self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ self._message.message_metadata = (
if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit() db.session.commit()
db.session.refresh(self._message) db.session.refresh(self._message)
@ -466,13 +451,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(message=queue_message) tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id) yield self._message_to_stream_response(
answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
yield self._message_replace_to_stream_response(answer=event.text) yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent): elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state: if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.') raise Exception("Graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer: if output_moderation_answer:
@ -502,8 +489,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.answer = self._task_state.answer self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ self._message.message_metadata = (
if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
if graph_runtime_state and graph_runtime_state.llm_usage: if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage usage = graph_runtime_state.llm_usage
@ -523,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation, conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None, is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras extras=self._application_generate_entity.extras,
) )
def _message_end_to_stream_response(self) -> MessageEndStreamResponse: def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@ -533,15 +521,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
""" """
extras = {} extras = {}
if self._task_state.metadata: if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata.copy() extras["metadata"] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']: if "annotation_reply" in extras["metadata"]:
del extras['metadata']['annotation_reply'] del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
id=self._message.id,
**extras
) )
def _handle_output_moderation_chunk(self, text: str) -> bool: def _handle_output_moderation_chunk(self, text: str) -> bool:
@ -555,14 +541,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# stop subscribe new token when output moderation should direct output # stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output() self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish( self._queue_manager.publish(
QueueTextChunkEvent( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
) )
self._queue_manager.publish( self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
PublishFrom.TASK_PIPELINE
) )
return True return True
else: else:

@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
""" """
Agent Chatbot App Config Entity. Agent Chatbot App Config Entity.
""" """
agent: Optional[AgentEntity] = None agent: Optional[AgentEntity] = None
class AgentChatAppConfigManager(BaseAppConfigManager): class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def get_app_config(cls, app_model: App, def get_app_config(
app_model_config: AppModelConfig, cls,
conversation: Optional[Conversation] = None, app_model: App,
override_config_dict: Optional[dict] = None) -> AgentChatAppConfig: app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> AgentChatAppConfig:
""" """
Convert app model config to agent chat app config Convert app model config to agent chat app config
:param app_model: app model :param app_model: app model
@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from, app_model_config_from=config_from,
app_model_config_id=app_model_config.id, app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict, app_model_config_dict=config_dict,
model=ModelConfigManager.convert( model=ModelConfigManager.convert(config=config_dict),
config=config_dict prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert( dataset=DatasetConfigManager.convert(config=config_dict),
config=config_dict agent=AgentConfigManager.convert(config=config_dict),
), additional_features=cls.convert_features(config_dict, app_mode),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
agent=AgentConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
) )
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer # suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config) config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# speech_to_text # speech_to_text
@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# dataset configs # dataset configs
# dataset_query_variable # dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
config) tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
config) tenant_id, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys)) related_config_keys = list(set(related_config_keys))
@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
:param config: app model config args :param config: app model config args
""" """
if not config.get("agent_mode"): if not config.get("agent_mode"):
config["agent_mode"] = { config["agent_mode"] = {"enabled": False, "tools": []}
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict): if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type") raise ValueError("agent_mode must be of object type")
@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config["agent_mode"].get("strategy"): if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [member.value for member in if config["agent_mode"]["strategy"] not in [
list(PlanningStrategy.__members__.values())]: member.value for member in list(PlanningStrategy.__members__.values())
]:
raise ValueError("strategy in agent_mode must be in the specified strategy list") raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"): if not config["agent_mode"].get("tools"):
@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode.tools must be of boolean type") raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "dataset": if key == "dataset":
if 'id' not in tool_item: if "id" not in tool_item:
raise ValueError("id is required in dataset") raise ValueError("id is required in dataset")
try: try:

@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator): class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -39,7 +40,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -48,19 +50,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = False, stream: bool = False,
) -> dict | Generator[dict | str, None, None]: ... ) -> dict | Generator[dict | str, None, None]: ...
def generate(self, app_model: App, def generate(
user: Union[Account, EndUser], self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
args: Any, ) -> Union[dict, Generator[dict | str, None, None]]:
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[dict | str, None, None]]:
""" """
Generate App response. Generate App response.
@ -71,60 +71,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream :param stream: is stream
""" """
if not stream: if not stream:
raise ValueError('Agent Chat App does not support blocking mode') raise ValueError("Agent Chat App does not support blocking mode")
if not args.get('query'): if not args.get("query"):
raise ValueError('query is required') raise ValueError("query is required")
query = args['query'] query = args["query"]
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError('query must be a string') raise ValueError("query must be a string")
query = query.replace('\x00', '') query = query.replace("\x00", "")
inputs = args['inputs'] inputs = args["inputs"]
extras = { extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
# get conversation # get conversation
conversation = None conversation = None
if args.get('conversation_id'): if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
# get app model config # get app model config
app_model_config = self._get_app_model_config( app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
app_model=app_model,
conversation=conversation
)
# validate override model config # validate override model config
override_model_config_dict = None override_model_config_dict = None
if args.get('model_config'): if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER: if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config') raise ValueError("Only in App debug mode can override model config")
# validate config # validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate( override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id, config=args.get("model_config")
config=args.get('model_config')
) )
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = { override_model_config_dict["retriever_resource"] = {"enabled": True}
"enabled": True
}
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
@ -133,7 +121,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model, app_model=app_model,
app_model_config=app_model_config, app_model_config=app_model_config,
conversation=conversation, conversation=conversation,
override_config_dict=override_model_config_dict override_config_dict=override_model_config_dict,
) )
# get tracing instance # get tracing instance
@ -154,14 +142,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
call_depth=0, call_depth=0,
trace_manager=trace_manager trace_manager=trace_manager,
) )
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@ -170,17 +155,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(),
'conversation_id': conversation.id, "application_generate_entity": application_generate_entity,
'message_id': message.id, "queue_manager": queue_manager,
}) "conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start() worker_thread.start()
@ -194,13 +182,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
) )
return AgentChatAppGenerateResponseConverter.convert( return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
def _generate_worker( def _generate_worker(
self, flask_app: Flask, self,
flask_app: Flask,
application_generate_entity: AgentChatAppGenerateEntity, application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation_id: str, conversation_id: str,
@ -229,18 +215,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought from models.model import App, Conversation, Message, MessageAgentThought
@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
""" """
def run( def run(
self, application_generate_entity: AgentChatAppGenerateEntity, self,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
memory = None memory = None
@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
# get memory of conversation (read-only) # get memory of conversation (read-only)
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
memory = TokenBufferMemory( memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
conversation=conversation,
model_instance=model_instance
)
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
memory=memory memory=memory,
) )
# moderation # moderation
@ -103,15 +101,15 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
message_id=message.id message_id=message.id,
) )
except ModerationException as e: except ModerationError as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
message=message, message=message,
query=query, query=query,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from invoke_from=application_generate_entity.invoke_from,
) )
if annotation_reply: if annotation_reply:
queue_manager.publish( queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER PublishFrom.APPLICATION_MANAGER,
) )
self.direct_output( self.direct_output(
@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=annotation_reply.content, text=annotation_reply.content,
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id, app_id=app_record.id,
external_data_tools=external_data_tools, external_data_tools=external_data_tools,
inputs=inputs, inputs=inputs,
query=query query=query,
) )
# reorganize all inputs and template to prompt messages # reorganize all inputs and template to prompt messages
@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
memory=memory memory=memory,
) )
# check hosting moderation # check hosting moderation
hosting_moderation_result = self.check_hosting_moderation( hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
prompt_messages=prompt_messages prompt_messages=prompt_messages,
) )
if hosting_moderation_result: if hosting_moderation_result:
@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
agent_entity = app_config.agent agent_entity = app_config.agent
# load tool variables # load tool variables
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, tool_conversation_variables = self._load_tool_variables(
user_id=application_generate_entity.user_id, conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
tenant_id=app_config.tenant_id) )
# convert db variables to tool variables # convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
# init model instance # init model instance
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
prompt_message, _ = self.organize_prompt_messages( prompt_message, _ = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
prompt_messages=prompt_message, prompt_messages=prompt_message,
variables_pool=tool_variables, variables_pool=tool_variables,
db_variables=tool_conversation_variables, db_variables=tool_conversation_variables,
model_instance=model_instance model_instance=model_instance,
) )
invoke_result = runner.run( invoke_result = runner.run(
@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
invoke_result=invoke_result, invoke_result=invoke_result,
queue_manager=queue_manager, queue_manager=queue_manager,
stream=application_generate_entity.stream, stream=application_generate_entity.stream,
agent=True agent=True,
) )
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
""" """
load tool variables from database load tool variables from database
""" """
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( tool_variables: ToolConversationVariables = (
ToolConversationVariables.conversation_id == conversation_id, db.session.query(ToolConversationVariables)
ToolConversationVariables.tenant_id == tenant_id .filter(
).first() ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
if tool_variables: if tool_variables:
# save tool variables to session, so that we can update it later # save tool variables to session, so that we can update it later
@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
conversation_id=conversation_id, conversation_id=conversation_id,
user_id=user_id, user_id=user_id,
tenant_id=tenant_id, tenant_id=tenant_id,
variables_str='[]', variables_str="[]",
) )
db.session.add(tool_variables) db.session.add(tool_variables)
db.session.commit() db.session.commit()
return tool_variables return tool_variables
def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool: def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
""" """
convert db variables to tool variables convert db variables to tool variables
""" """
return ToolRuntimeVariablePool(**{ return ToolRuntimeVariablePool(
'conversation_id': db_variables.conversation_id, **{
'user_id': db_variables.user_id, "conversation_id": db_variables.conversation_id,
'tenant_id': db_variables.tenant_id, "user_id": db_variables.user_id,
'pool': db_variables.variables "tenant_id": db_variables.tenant_id,
}) "pool": db_variables.variables,
}
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, )
message: Message) -> LLMUsage:
def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
""" """
Get usage of all agent thoughts Get usage of all agent thoughts
:param model_config: model config :param model_config: model config
:param message: message :param message: message
:return: :return:
""" """
agent_thoughts = (db.session.query(MessageAgentThought) agent_thoughts = (
.filter(MessageAgentThought.message_id == message.id).all()) db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
all_message_tokens = 0 all_message_tokens = 0
all_answer_tokens = 0 all_answer_tokens = 0
@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage( return model_type_instance._calc_response_usage(
model_config.model, model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
model_config.credentials,
all_message_tokens,
all_answer_tokens
) )

@ -22,15 +22,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return: :return:
""" """
response = { response = {
'event': 'message', "event": "message",
'task_id': blocking_response.task_id, "task_id": blocking_response.task_id,
'id': blocking_response.data.id, "id": blocking_response.data.id,
'message_id': blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id, "conversation_id": blocking_response.data.conversation_id,
'mode': blocking_response.data.mode, "mode": blocking_response.data.mode,
'answer': blocking_response.data.answer, "answer": blocking_response.data.answer,
'metadata': blocking_response.data.metadata, "metadata": blocking_response.data.metadata,
'created_at': blocking_response.data.created_at "created_at": blocking_response.data.created_at,
} }
return response return response
@ -44,8 +44,8 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {}) metadata = response.get("metadata", {})
response['metadata'] = cls._get_simple_metadata(metadata) response["metadata"] = cls._get_simple_metadata(metadata)
return response return response
@ -62,14 +62,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -92,20 +92,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)

@ -13,11 +13,10 @@ class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse] _blocking_response_type: type[AppBlockingResponse]
@classmethod @classmethod
def convert(cls, response: Union[ def convert(
AppBlockingResponse, cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
Generator[AppStreamResponse, Any, None] ) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(response)
else: else:
@ -52,8 +51,9 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ def convert_stream_simple_response(
-> Generator[str, None, None]: cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@ -64,24 +64,26 @@ class AppGenerateResponseConverter(ABC):
:return: :return:
""" """
# show_retrieve_source # show_retrieve_source
if 'retriever_resources' in metadata: if "retriever_resources" in metadata:
metadata['retriever_resources'] = [] metadata["retriever_resources"] = []
for resource in metadata['retriever_resources']: for resource in metadata["retriever_resources"]:
metadata['retriever_resources'].append({ metadata["retriever_resources"].append(
'segment_id': resource['segment_id'], {
'position': resource['position'], "segment_id": resource["segment_id"],
'document_name': resource['document_name'], "position": resource["position"],
'score': resource['score'], "document_name": resource["document_name"],
'content': resource['content'], "score": resource["score"],
}) "content": resource["content"],
}
)
# show annotation reply # show annotation reply
if 'annotation_reply' in metadata: if "annotation_reply" in metadata:
del metadata['annotation_reply'] del metadata["annotation_reply"]
# show usage # show usage
if 'usage' in metadata: if "usage" in metadata:
del metadata['usage'] del metadata["usage"]
return metadata return metadata
@ -93,16 +95,16 @@ class AppGenerateResponseConverter(ABC):
:return: :return:
""" """
error_responses = { error_responses = {
ValueError: {'code': 'invalid_param', 'status': 400}, ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: { QuotaExceededError: {
'code': 'provider_quota_exceeded', "code": "provider_quota_exceeded",
'message': "Your quota for Dify Hosted Model Provider has been exhausted. " "message": "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.", "Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400 "status": 400,
}, },
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
InvokeError: {'code': 'completion_request_error', 'status': 400} InvokeError: {"code": "completion_request_error", "status": 400},
} }
# Determine the response based on the type of exception # Determine the response based on the type of exception
@ -112,13 +114,13 @@ class AppGenerateResponseConverter(ABC):
data = v data = v
if data: if data:
data.setdefault('message', getattr(e, 'description', str(e))) data.setdefault("message", getattr(e, "description", str(e)))
else: else:
logging.error(e) logging.error(e)
data = { data = {
'code': 'internal_server_error', "code": "internal_server_error",
'message': 'Internal Server Error, please contact support.', "message": "Internal Server Error, please contact support.",
'status': 500 "status": 500,
} }
return data return data

@ -17,17 +17,17 @@ class BaseAppGenerator:
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable) user_input_value = inputs.get(var.variable)
if var.required and not user_input_value: if var.required and not user_input_value:
raise ValueError(f'{var.variable} is required in input form') raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value: if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None? # TODO: should we return None here if the default value is None?
return var.default or '' return var.default or ""
if ( if (
var.type var.type
in ( in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT, VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH, VariableEntityType.PARAGRAPH,
) }
and user_input_value and user_input_value
and not isinstance(user_input_value, str) and not isinstance(user_input_value, str)
): ):
@ -35,7 +35,7 @@ class BaseAppGenerator:
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number # may raise ValueError if user_input_value is not a valid number
try: try:
if '.' in user_input_value: if "." in user_input_value:
return float(user_input_value) return float(user_input_value)
else: else:
return int(user_input_value) return int(user_input_value)
@ -44,20 +44,20 @@ class BaseAppGenerator:
if var.type == VariableEntityType.SELECT: if var.type == VariableEntityType.SELECT:
options = var.options or [] options = var.options or []
if user_input_value not in options: if user_input_value not in options:
raise ValueError(f'{var.variable} in input form must be one of the following: {options}') raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and user_input_value and len(user_input_value) > var.max_length: if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
return user_input_value return user_input_value
def _sanitize_value(self, value: Any) -> Any: def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str): if isinstance(value, str):
return value.replace('\x00', '') return value.replace("\x00", "")
return value return value
@classmethod @classmethod
def convert_to_event_stream(cls, generator: Union[dict, Generator[dict| str, None, None]]): def convert_to_event_stream(cls, generator: Union[dict, Generator[dict | str, None, None]]):
""" """
Convert messages into event stream Convert messages into event stream
""" """

@ -24,9 +24,7 @@ class PublishFrom(Enum):
class AppQueueManager: class AppQueueManager:
def __init__(self, task_id: str, def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
user_id: str,
invoke_from: InvokeFrom) -> None:
if not user_id: if not user_id:
raise ValueError("user is required") raise ValueError("user is required")
@ -34,9 +32,10 @@ class AppQueueManager:
self._user_id = user_id self._user_id = user_id
self._invoke_from = invoke_from self._invoke_from = invoke_from
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, redis_client.setex(
f"{user_prefix}-{self._user_id}") AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
q = queue.Queue() q = queue.Queue()
@ -66,8 +65,7 @@ class AppQueueManager:
# publish two messages to make sure the client can receive the stop signal # publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed # and stop listening after the stop signal processed
self.publish( self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
PublishFrom.TASK_PIPELINE
) )
if elapsed_time // 10 > last_ping_time: if elapsed_time // 10 > last_ping_time:
@ -88,9 +86,7 @@ class AppQueueManager:
:param pub_from: publish from :param pub_from: publish from
:return: :return:
""" """
self.publish(QueueErrorEvent( self.publish(QueueErrorEvent(error=e), pub_from)
error=e
), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
""" """
@ -122,8 +118,8 @@ class AppQueueManager:
if result is None: if result is None:
return return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
if result.decode('utf-8') != f"{user_prefix}-{user_id}": if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return return
stopped_cache_key = cls._generate_stopped_cache_key(task_id) stopped_cache_key = cls._generate_stopped_cache_key(task_id)
@ -168,10 +164,12 @@ class AppQueueManager:
for item in data: for item in data:
self._check_for_sqlalchemy_models(item) self._check_for_sqlalchemy_models(item)
else: else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances " raise TypeError(
"that cause thread safety issues is not allowed.") "Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
)
class GenerateTaskStoppedException(Exception): class GenerateTaskStoppedError(Exception):
pass pass

@ -31,12 +31,15 @@ if TYPE_CHECKING:
class AppRunner: class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App, def get_pre_calculate_rest_tokens(
model_config: ModelConfigWithCredentialsEntity, self,
prompt_template_entity: PromptTemplateEntity, app_record: App,
inputs: dict[str, str], model_config: ModelConfigWithCredentialsEntity,
files: list["FileVar"], prompt_template_entity: PromptTemplateEntity,
query: Optional[str] = None) -> int: inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
) -> int:
""" """
Get pre calculate rest tokens Get pre calculate rest tokens
:param app_record: app record :param app_record: app record
@ -49,18 +52,20 @@ class AppRunner:
""" """
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
model=model_config.model
) )
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0 max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules: for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens' if parameter_rule.name == "max_tokens" or (
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
max_tokens = (model_config.parameters.get(parameter_rule.name) ):
or model_config.parameters.get(parameter_rule.use_template)) or 0 max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None: if model_context_tokens is None:
return -1 return -1
@ -75,36 +80,39 @@ class AppRunner:
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
prompt_tokens = model_instance.get_llm_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
prompt_messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0: if rest_tokens < 0:
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " raise InvokeBadRequestError(
"or shrink the max token, or switch to a llm with a larger token limit size.") "Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)
return rest_tokens return rest_tokens
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, def recalc_llm_max_tokens(
prompt_messages: list[PromptMessage]): self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
model=model_config.model
) )
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0 max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules: for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens' if parameter_rule.name == "max_tokens" or (
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
max_tokens = (model_config.parameters.get(parameter_rule.name) ):
or model_config.parameters.get(parameter_rule.use_template)) or 0 max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None: if model_context_tokens is None:
return -1 return -1
@ -112,27 +120,28 @@ class AppRunner:
if max_tokens is None: if max_tokens is None:
max_tokens = 0 max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
prompt_messages
)
if prompt_tokens + max_tokens > model_context_tokens: if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16) max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules: for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens' if parameter_rule.name == "max_tokens" or (
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(self, app_record: App, def organize_prompt_messages(
model_config: ModelConfigWithCredentialsEntity, self,
prompt_template_entity: PromptTemplateEntity, app_record: App,
inputs: dict[str, str], model_config: ModelConfigWithCredentialsEntity,
files: list["FileVar"], prompt_template_entity: PromptTemplateEntity,
query: Optional[str] = None, inputs: dict[str, str],
context: Optional[str] = None, files: list["FileVar"],
memory: Optional[TokenBufferMemory] = None) \ query: Optional[str] = None,
-> tuple[list[PromptMessage], Optional[list[str]]]: context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
""" """
Organize prompt messages Organize prompt messages
:param context: :param context:
@ -152,60 +161,54 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode), app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
query=query if query else '', query=query or "",
files=files, files=files,
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config model_config=model_config,
) )
else: else:
memory_config = MemoryConfig( memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
window=MemoryConfig.WindowConfig(
enabled=False
)
)
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.COMPLETION: if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
prompt_template = CompletionModelPromptTemplate( prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
text=advanced_completion_prompt_template.prompt
)
if advanced_completion_prompt_template.role_prefix: if advanced_completion_prompt_template.role_prefix:
memory_config.role_prefix = MemoryConfig.RolePrefix( memory_config.role_prefix = MemoryConfig.RolePrefix(
user=advanced_completion_prompt_template.role_prefix.user, user=advanced_completion_prompt_template.role_prefix.user,
assistant=advanced_completion_prompt_template.role_prefix.assistant assistant=advanced_completion_prompt_template.role_prefix.assistant,
) )
else: else:
prompt_template = [] prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages: for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage( prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
text=message.text,
role=message.role
))
prompt_transform = AdvancedPromptTransform() prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt( prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template, prompt_template=prompt_template,
inputs=inputs, inputs=inputs,
query=query if query else '', query=query or "",
files=files, files=files,
context=context, context=context,
memory_config=memory_config, memory_config=memory_config,
memory=memory, memory=memory,
model_config=model_config model_config=model_config,
) )
stop = model_config.stop stop = model_config.stop
return prompt_messages, stop return prompt_messages, stop
def direct_output(self, queue_manager: AppQueueManager, def direct_output(
app_generate_entity: EasyUIBasedAppGenerateEntity, self,
prompt_messages: list, queue_manager: AppQueueManager,
text: str, app_generate_entity: EasyUIBasedAppGenerateEntity,
stream: bool, prompt_messages: list,
usage: Optional[LLMUsage] = None) -> None: text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
""" """
Direct output Direct output
:param queue_manager: application queue manager :param queue_manager: application queue manager
@ -222,17 +225,10 @@ class AppRunner:
chunk = LLMResultChunk( chunk = LLMResultChunk(
model=app_generate_entity.model_conf.model, model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
index=index,
message=AssistantPromptMessage(content=token)
)
) )
queue_manager.publish( queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
QueueLLMChunkEvent(
chunk=chunk
), PublishFrom.APPLICATION_MANAGER
)
index += 1 index += 1
time.sleep(0.01) time.sleep(0.01)
@ -242,15 +238,19 @@ class AppRunner:
model=app_generate_entity.model_conf.model, model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text), message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage() usage=usage or LLMUsage.empty_usage(),
), ),
), PublishFrom.APPLICATION_MANAGER ),
PublishFrom.APPLICATION_MANAGER,
) )
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], def _handle_invoke_result(
queue_manager: AppQueueManager, self,
stream: bool, invoke_result: Union[LLMResult, Generator],
agent: bool = False) -> None: queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
@ -260,21 +260,13 @@ class AppRunner:
:return: :return:
""" """
if not stream: if not stream:
self._handle_invoke_result_direct( self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
else: else:
self._handle_invoke_result_stream( self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
def _handle_invoke_result_direct(self, invoke_result: LLMResult, def _handle_invoke_result_direct(
queue_manager: AppQueueManager, self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
agent: bool) -> None: ) -> None:
""" """
Handle invoke result direct Handle invoke result direct
:param invoke_result: invoke result :param invoke_result: invoke result
@ -285,12 +277,13 @@ class AppRunner:
queue_manager.publish( queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
llm_result=invoke_result, llm_result=invoke_result,
), PublishFrom.APPLICATION_MANAGER ),
PublishFrom.APPLICATION_MANAGER,
) )
def _handle_invoke_result_stream(self, invoke_result: Generator, def _handle_invoke_result_stream(
queue_manager: AppQueueManager, self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
agent: bool) -> None: ) -> None:
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
@ -300,21 +293,13 @@ class AppRunner:
""" """
model = None model = None
prompt_messages = [] prompt_messages = []
text = '' text = ""
usage = None usage = None
for result in invoke_result: for result in invoke_result:
if not agent: if not agent:
queue_manager.publish( queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
QueueLLMChunkEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
else: else:
queue_manager.publish( queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
QueueAgentMessageEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
text += result.delta.message.content text += result.delta.message.content
@ -331,25 +316,24 @@ class AppRunner:
usage = LLMUsage.empty_usage() usage = LLMUsage.empty_usage()
llm_result = LLMResult( llm_result = LLMResult(
model=model, model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
) )
queue_manager.publish( queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
llm_result=llm_result, llm_result=llm_result,
), PublishFrom.APPLICATION_MANAGER ),
PublishFrom.APPLICATION_MANAGER,
) )
def moderation_for_inputs( def moderation_for_inputs(
self, app_id: str, self,
tenant_id: str, app_id: str,
app_generate_entity: AppGenerateEntity, tenant_id: str,
inputs: Mapping[str, Any], app_generate_entity: AppGenerateEntity,
query: str, inputs: Mapping[str, Any],
message_id: str, query: str,
message_id: str,
) -> tuple[bool, dict, str]: ) -> tuple[bool, dict, str]:
""" """
Process sensitive_word_avoidance. Process sensitive_word_avoidance.
@ -367,14 +351,17 @@ class AppRunner:
tenant_id=tenant_id, tenant_id=tenant_id,
app_config=app_generate_entity.app_config, app_config=app_generate_entity.app_config,
inputs=inputs, inputs=inputs,
query=query if query else '', query=query or "",
message_id=message_id, message_id=message_id,
trace_manager=app_generate_entity.trace_manager trace_manager=app_generate_entity.trace_manager,
) )
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def check_hosting_moderation(
queue_manager: AppQueueManager, self,
prompt_messages: list[PromptMessage]) -> bool: application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage],
) -> bool:
""" """
Check hosting moderation Check hosting moderation
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -384,8 +371,7 @@ class AppRunner:
""" """
hosting_moderation_feature = HostingModerationFeature() hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check( moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
prompt_messages=prompt_messages
) )
if moderation_result: if moderation_result:
@ -393,18 +379,20 @@ class AppRunner:
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text="I apologize for any confusion, " \ text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
"but I'm an AI assistant to be helpful, harmless, and honest.", stream=application_generate_entity.stream,
stream=application_generate_entity.stream
) )
return moderation_result return moderation_result
def fill_in_inputs_from_external_data_tools(self, tenant_id: str, def fill_in_inputs_from_external_data_tools(
app_id: str, self,
external_data_tools: list[ExternalDataVariableEntity], tenant_id: str,
inputs: dict, app_id: str,
query: str) -> dict: external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
""" """
Fill in variable inputs from external data tools if exists. Fill in variable inputs from external data tools if exists.
@ -417,18 +405,12 @@ class AppRunner:
""" """
external_data_fetch_feature = ExternalDataFetch() external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch( return external_data_fetch_feature.fetch(
tenant_id=tenant_id, tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
) )
def query_app_annotations_to_reply(self, app_record: App, def query_app_annotations_to_reply(
message: Message, self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
query: str, ) -> Optional[MessageAnnotation]:
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
""" """
Query app annotations to reply Query app annotations to reply
:param app_record: app record :param app_record: app record
@ -440,9 +422,5 @@ class AppRunner:
""" """
annotation_reply_feature = AnnotationReplyFeature() annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query( return annotation_reply_feature.query(
app_record=app_record, app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
) )

@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
""" """
Chatbot App Config Entity. Chatbot App Config Entity.
""" """
pass pass
class ChatAppConfigManager(BaseAppConfigManager): class ChatAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def get_app_config(cls, app_model: App, def get_app_config(
app_model_config: AppModelConfig, cls,
conversation: Optional[Conversation] = None, app_model: App,
override_config_dict: Optional[dict] = None) -> ChatAppConfig: app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> ChatAppConfig:
""" """
Convert app model config to chat app config Convert app model config to chat app config
:param app_model: app model :param app_model: app model
@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
config_dict = app_model_config_dict.copy() config_dict = app_model_config_dict.copy()
else: else:
if not override_config_dict: if not override_config_dict:
raise Exception('override_config_dict is required when config_from is ARGS') raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict config_dict = override_config_dict
@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from, app_model_config_from=config_from,
app_model_config_id=app_model_config.id, app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict, app_model_config_dict=config_dict,
model=ModelConfigManager.convert( model=ModelConfigManager.convert(config=config_dict),
config=config_dict prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert( dataset=DatasetConfigManager.convert(config=config_dict),
config=config_dict additional_features=cls.convert_features(config_dict, app_mode),
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
) )
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# dataset_query_variable # dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
config) tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# opening_statement # opening_statement
@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer # suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config) config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# speech_to_text # speech_to_text
@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
config) tenant_id, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys)) related_config_keys = list(set(related_config_keys))

@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator): class ChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Any,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Any,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -56,7 +58,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) -> Union[dict, Generator[dict | str, None, None]]: ... ) -> Union[dict, Generator[dict | str, None, None]]: ...
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Any, args: Any,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -71,58 +74,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
if not args.get('query'): if not args.get("query"):
raise ValueError('query is required') raise ValueError("query is required")
query = args['query'] query = args["query"]
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError('query must be a string') raise ValueError("query must be a string")
query = query.replace('\x00', '') query = query.replace("\x00", "")
inputs = args['inputs'] inputs = args["inputs"]
extras = { extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
# get conversation # get conversation
conversation = None conversation = None
if args.get('conversation_id'): if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
# get app model config # get app model config
app_model_config = self._get_app_model_config( app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
app_model=app_model,
conversation=conversation
)
# validate override model config # validate override model config
override_model_config_dict = None override_model_config_dict = None
if args.get('model_config'): if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER: if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config') raise ValueError("Only in App debug mode can override model config")
# validate config # validate config
override_model_config_dict = ChatAppConfigManager.config_validate( override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id, config=args.get("model_config")
config=args.get('model_config')
) )
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = { override_model_config_dict["retriever_resource"] = {"enabled": True}
"enabled": True
}
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
@ -131,7 +122,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model, app_model=app_model,
app_model_config=app_model_config, app_model_config=app_model_config,
conversation=conversation, conversation=conversation,
override_config_dict=override_model_config_dict override_config_dict=override_model_config_dict,
) )
# get tracing instance # get tracing instance
@ -150,14 +141,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
trace_manager=trace_manager trace_manager=trace_manager,
) )
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@ -166,17 +154,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(),
'conversation_id': conversation.id, "application_generate_entity": application_generate_entity,
'message_id': message.id, "queue_manager": queue_manager,
}) "conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start() worker_thread.start()
@ -190,16 +181,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
) )
return ChatAppGenerateResponseConverter.convert( return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
def _generate_worker(self, flask_app: Flask, def _generate_worker(
application_generate_entity: ChatAppGenerateEntity, self,
queue_manager: AppQueueManager, flask_app: Flask,
conversation_id: str, application_generate_entity: ChatAppGenerateEntity,
message_id: str) -> None: queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -221,20 +212,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
conversation=conversation, conversation=conversation,
message=message message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Conversation, Message from models.model import App, Conversation, Message
@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
Chat Application Runner Chat Application Runner
""" """
def run(self, application_generate_entity: ChatAppGenerateEntity, def run(
queue_manager: AppQueueManager, self,
conversation: Conversation, application_generate_entity: ChatAppGenerateEntity,
message: Message) -> None: queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
memory = None memory = None
@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
# get memory of conversation (read-only) # get memory of conversation (read-only)
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
memory = TokenBufferMemory( memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
conversation=conversation,
model_instance=model_instance
)
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
memory=memory memory=memory,
) )
# moderation # moderation
@ -96,15 +96,15 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
message_id=message.id message_id=message.id,
) )
except ModerationException as e: except ModerationError as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
message=message, message=message,
query=query, query=query,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from invoke_from=application_generate_entity.invoke_from,
) )
if annotation_reply: if annotation_reply:
queue_manager.publish( queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER PublishFrom.APPLICATION_MANAGER,
) )
self.direct_output( self.direct_output(
@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=annotation_reply.content, text=annotation_reply.content,
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
app_id=app_record.id, app_id=app_record.id,
external_data_tools=external_data_tools, external_data_tools=external_data_tools,
inputs=inputs, inputs=inputs,
query=query query=query,
) )
# get context from datasets # get context from datasets
@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
app_record.id, app_record.id,
message.id, message.id,
application_generate_entity.user_id, application_generate_entity.user_id,
application_generate_entity.invoke_from application_generate_entity.invoke_from,
) )
dataset_retrieval = DatasetRetrieval(application_generate_entity) dataset_retrieval = DatasetRetrieval(application_generate_entity)
@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
files=files, files=files,
query=query, query=query,
context=context, context=context,
memory=memory memory=memory,
) )
# check hosting moderation # check hosting moderation
hosting_moderation_result = self.check_hosting_moderation( hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
prompt_messages=prompt_messages prompt_messages=prompt_messages,
) )
if hosting_moderation_result: if hosting_moderation_result:
return return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens( self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
db.session.close() db.session.close()
@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
queue_manager=queue_manager,
stream=application_generate_entity.stream
) )

@ -22,15 +22,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return: :return:
""" """
response = { response = {
'event': 'message', "event": "message",
'task_id': blocking_response.task_id, "task_id": blocking_response.task_id,
'id': blocking_response.data.id, "id": blocking_response.data.id,
'message_id': blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id, "conversation_id": blocking_response.data.conversation_id,
'mode': blocking_response.data.mode, "mode": blocking_response.data.mode,
'answer': blocking_response.data.answer, "answer": blocking_response.data.answer,
'metadata': blocking_response.data.metadata, "metadata": blocking_response.data.metadata,
'created_at': blocking_response.data.created_at "created_at": blocking_response.data.created_at,
} }
return response return response
@ -44,8 +44,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {}) metadata = response.get("metadata", {})
response['metadata'] = cls._get_simple_metadata(metadata) response["metadata"] = cls._get_simple_metadata(metadata)
return response return response
@ -62,14 +62,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -92,20 +92,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'conversation_id': chunk.conversation_id, "conversation_id": chunk.conversation_id,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)

@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
""" """
Completion App Config Entity. Completion App Config Entity.
""" """
pass pass
class CompletionAppConfigManager(BaseAppConfigManager): class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def get_app_config(cls, app_model: App, def get_app_config(
app_model_config: AppModelConfig, cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
override_config_dict: Optional[dict] = None) -> CompletionAppConfig: ) -> CompletionAppConfig:
""" """
Convert app model config to completion app config Convert app model config to completion app config
:param app_model: app model :param app_model: app model
@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from, app_model_config_from=config_from,
app_model_config_id=app_model_config.id, app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict, app_model_config_dict=config_dict,
model=ModelConfigManager.convert( model=ModelConfigManager.convert(config=config_dict),
config=config_dict prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert( dataset=DatasetConfigManager.convert(config=config_dict),
config=config_dict additional_features=cls.convert_features(config_dict, app_mode),
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
) )
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# dataset_query_variable # dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
config) tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# text_to_speech # text_to_speech
@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
config) tenant_id, config
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys)) related_config_keys = list(set(related_config_keys))

@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator): class CompletionAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -41,7 +42,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
@ -72,12 +74,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
query = args['query'] query = args["query"]
if not isinstance(query, str): if not isinstance(query, str):
raise ValueError('query must be a string') raise ValueError("query must be a string")
query = query.replace('\x00', '') query = query.replace("\x00", "")
inputs = args['inputs'] inputs = args["inputs"]
extras = {} extras = {}
@ -85,41 +87,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
conversation = None conversation = None
# get app model config # get app model config
app_model_config = self._get_app_model_config( app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
app_model=app_model,
conversation=conversation
)
# validate override model config # validate override model config
override_model_config_dict = None override_model_config_dict = None
if args.get('model_config'): if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER: if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config') raise ValueError("Only in App debug mode can override model config")
# validate config # validate config
override_model_config_dict = CompletionAppConfigManager.config_validate( override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id, config=args.get("model_config")
config=args.get('model_config')
) )
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = CompletionAppConfigManager.get_app_config( app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
) )
# get tracing instance # get tracing instance
@ -137,14 +129,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,
trace_manager=trace_manager trace_manager=trace_manager,
) )
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity)
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@ -153,16 +142,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(),
'message_id': message.id, "application_generate_entity": application_generate_entity,
}) "queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start() worker_thread.start()
@ -176,15 +168,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
) )
return CompletionAppGenerateResponseConverter.convert( return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
def _generate_worker(self, flask_app: Flask, def _generate_worker(
application_generate_entity: CompletionAppGenerateEntity, self,
queue_manager: AppQueueManager, flask_app: Flask,
message_id: str) -> None: application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -203,20 +195,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
runner.run( runner.run(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
message=message message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
@ -225,12 +216,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
finally: finally:
db.session.close() db.session.close()
def generate_more_like_this(self, app_model: App, def generate_more_like_this(
message_id: str, self,
user: Union[Account, EndUser], app_model: App,
invoke_from: InvokeFrom, message_id: str,
stream: bool = True) \ user: Union[Account, EndUser],
-> Union[dict, Generator[str, None, None]]: invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
""" """
Generate App response. Generate App response.
@ -240,13 +233,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
message = db.session.query(Message).filter( message = (
Message.id == message_id, db.session.query(Message)
Message.app_id == app_model.id, .filter(
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), Message.id == message_id,
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Message.app_id == app_model.id,
Message.from_account_id == (user.id if isinstance(user, Account) else None), Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
).first() Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)
if not message: if not message:
raise MessageNotExistsError() raise MessageNotExistsError()
@ -259,29 +256,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
app_model_config = message.app_model_config app_model_config = message.app_model_config
override_model_config_dict = app_model_config.to_dict() override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict['model'] model_dict = override_model_config_dict["model"]
completion_params = model_dict.get('completion_params') completion_params = model_dict.get("completion_params")
completion_params['temperature'] = 0.9 completion_params["temperature"] = 0.9
model_dict['completion_params'] = completion_params model_dict["completion_params"] = completion_params
override_model_config_dict['model'] = model_dict override_model_config_dict["model"] = model_dict
# parse files # parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
message.files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = CompletionAppConfigManager.get_app_config( app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
) )
# init application generate entity # init application generate entity
@ -295,14 +286,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
extras={} extras={},
) )
# init generate records # init generate records
( (conversation, message) = self._init_generate_records(application_generate_entity)
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager # init queue manager
queue_manager = MessageBasedAppQueueManager( queue_manager = MessageBasedAppQueueManager(
@ -311,16 +299,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id, conversation_id=conversation.id,
app_mode=conversation.mode, app_mode=conversation.mode,
message_id=message.id message_id=message.id,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(),
'message_id': message.id, "application_generate_entity": application_generate_entity,
}) "queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start() worker_thread.start()
@ -334,7 +325,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream, stream=stream,
) )
return CompletionAppGenerateResponseConverter.convert( return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)

@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
) )
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Message from models.model import App, Message
@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
Completion Application Runner Completion Application Runner
""" """
def run(self, application_generate_entity: CompletionAppGenerateEntity, def run(
queue_manager: AppQueueManager, self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
message: Message) -> None: ) -> None:
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query,
) )
# moderation # moderation
@ -77,15 +77,15 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
message_id=message.id message_id=message.id,
) )
except ModerationException as e: except ModerationError as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream,
) )
return return
@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
app_id=app_record.id, app_id=app_record.id,
external_data_tools=external_data_tools, external_data_tools=external_data_tools,
inputs=inputs, inputs=inputs,
query=query query=query,
) )
# get context from datasets # get context from datasets
@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
app_record.id, app_record.id,
message.id, message.id,
application_generate_entity.user_id, application_generate_entity.user_id,
application_generate_entity.invoke_from application_generate_entity.invoke_from,
) )
dataset_config = app_config.dataset dataset_config = app_config.dataset
@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source, show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback, hit_callback=hit_callback,
message_id=message.id message_id=message.id,
) )
# reorganize all inputs and template to prompt messages # reorganize all inputs and template to prompt messages
@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
context=context context=context,
) )
# check hosting moderation # check hosting moderation
hosting_moderation_result = self.check_hosting_moderation( hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
prompt_messages=prompt_messages prompt_messages=prompt_messages,
) )
if hosting_moderation_result: if hosting_moderation_result:
return return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens( self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model model=application_generate_entity.model_conf.model,
) )
db.session.close() db.session.close()
@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
queue_manager=queue_manager,
stream=application_generate_entity.stream
) )

@ -22,14 +22,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
:return: :return:
""" """
response = { response = {
'event': 'message', "event": "message",
'task_id': blocking_response.task_id, "task_id": blocking_response.task_id,
'id': blocking_response.data.id, "id": blocking_response.data.id,
'message_id': blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
'mode': blocking_response.data.mode, "mode": blocking_response.data.mode,
'answer': blocking_response.data.answer, "answer": blocking_response.data.answer,
'metadata': blocking_response.data.metadata, "metadata": blocking_response.data.metadata,
'created_at': blocking_response.data.created_at "created_at": blocking_response.data.created_at,
} }
return response return response
@ -43,8 +43,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
""" """
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {}) metadata = response.get("metadata", {})
response['metadata'] = cls._get_simple_metadata(metadata) response["metadata"] = cls._get_simple_metadata(metadata)
return response return response
@ -61,13 +61,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -90,19 +90,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'message_id': chunk.message_id, "message_id": chunk.message_id,
'created_at': chunk.created_at "created_at": chunk.created_at,
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict() sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)

@ -8,7 +8,7 @@ from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity, AgentChatAppGenerateEntity,
@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator): class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response( def _handle_response(
self, application_generate_entity: Union[ self,
ChatAppGenerateEntity, application_generate_entity: Union[
CompletionAppGenerateEntity, ChatAppGenerateEntity,
AgentChatAppGenerateEntity, CompletionAppGenerateEntity,
AdvancedChatAppGenerateEntity AgentChatAppGenerateEntity,
], AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager, ],
conversation: Conversation, queue_manager: AppQueueManager,
message: Message, conversation: Conversation,
user: Union[Account, EndUser], message: Message,
stream: bool = False, user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ ) -> Union[
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]: ]:
""" """
Handle response. Handle response.
@ -70,24 +70,25 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
user=user, user=user,
stream=stream stream=stream,
) )
try: try:
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(e)
raise e raise e
def _get_conversation_by_user(self, app_model: App, conversation_id: str, def _get_conversation_by_user(
user: Union[Account, EndUser]) -> Conversation: self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
) -> Conversation:
conversation_filter = [ conversation_filter = [
Conversation.id == conversation_id, Conversation.id == conversation_id,
Conversation.app_id == app_model.id, Conversation.app_id == app_model.id,
Conversation.status == 'normal' Conversation.status == "normal",
] ]
if isinstance(user, Account): if isinstance(user, Account):
@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if not conversation: if not conversation:
raise ConversationNotExistsError() raise ConversationNotExistsError()
if conversation.status != 'normal': if conversation.status != "normal":
raise ConversationCompletedError() raise ConversationCompletedError()
return conversation return conversation
def _get_app_model_config(self, app_model: App, def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
conversation: Optional[Conversation] = None) \
-> AppModelConfig:
if conversation: if conversation:
app_model_config = db.session.query(AppModelConfig).filter( app_model_config = (
AppModelConfig.id == conversation.app_model_config_id, db.session.query(AppModelConfig)
AppModelConfig.app_id == app_model.id .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
).first() .first()
)
if not app_model_config: if not app_model_config:
raise AppModelConfigBrokenError() raise AppModelConfigBrokenError()
@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return app_model_config return app_model_config
def _init_generate_records(self, def _init_generate_records(
application_generate_entity: Union[ self,
ChatAppGenerateEntity, application_generate_entity: Union[
CompletionAppGenerateEntity, ChatAppGenerateEntity,
AgentChatAppGenerateEntity, CompletionAppGenerateEntity,
AdvancedChatAppGenerateEntity AgentChatAppGenerateEntity,
], AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None) \ ],
-> tuple[Conversation, Message]: conversation: Optional[Conversation] = None,
) -> tuple[Conversation, Message]:
""" """
Initialize generate records Initialize generate records
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -147,11 +148,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
# get from source # get from source
end_user_id = None end_user_id = None
account_id = None account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
from_source = 'api' from_source = "api"
end_user_id = application_generate_entity.user_id end_user_id = application_generate_entity.user_id
else: else:
from_source = 'console' from_source = "console"
account_id = application_generate_entity.user_id account_id = application_generate_entity.user_id
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model model_id = application_generate_entity.model_conf.model
override_model_configs = None override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
}:
override_model_configs = app_config.app_model_config_dict override_model_configs = app_config.app_model_config_dict
# get conversation introduction # get conversation introduction
@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_id=model_id, model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value, mode=app_config.app_mode.value,
name='New conversation', name="New conversation",
inputs=application_generate_entity.inputs, inputs=application_generate_entity.inputs,
introduction=introduction, introduction=introduction,
system_instruction="", system_instruction="",
system_instruction_tokens=0, system_instruction_tokens=0,
status='normal', status="normal",
invoke_from=application_generate_entity.invoke_from.value, invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source, from_source=from_source,
from_end_user_id=end_user_id, from_end_user_id=end_user_id,
@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
answer_price_unit=0, answer_price_unit=0,
provider_response_latency=0, provider_response_latency=0,
total_price=0, total_price=0,
currency='USD', currency="USD",
invoke_from=application_generate_entity.invoke_from.value, invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source, from_source=from_source,
from_end_user_id=end_user_id, from_end_user_id=end_user_id,
from_account_id=account_id from_account_id=account_id,
) )
db.session.add(message) db.session.add(message)
@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id, message_id=message.id,
type=file.type.value, type=file.type.value,
transfer_method=file.transfer_method.value, transfer_method=file.transfer_method.value,
belongs_to='user', belongs_to="user",
url=file.url, url=file.url,
upload_file_id=file.related_id, upload_file_id=file.related_id,
created_by_role=('account' if account_id else 'end_user'), created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id, created_by=account_id or end_user_id,
) )
db.session.add(message_file) db.session.add(message_file)
@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id :param conversation_id: conversation id
:return: conversation :return: conversation
""" """
conversation = ( conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
if not conversation: if not conversation:
raise ConversationNotExistsError() raise ConversationNotExistsError()
@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id :param message_id: message id
:return: message :return: message
""" """
message = ( message = db.session.query(Message).filter(Message.id == message_id).first()
db.session.query(Message)
.filter(Message.id == message_id)
.first()
)
return message return message

@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager): class MessageBasedAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, def __init__(
user_id: str, self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
invoke_from: InvokeFrom, ) -> None:
conversation_id: str,
app_mode: str,
message_id: str) -> None:
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id) self._conversation_id = str(conversation_id)
@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id, message_id=self._message_id,
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
app_mode=self._app_mode, app_mode=self._app_mode,
event=event event=event,
) )
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id, message_id=self._message_id,
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
app_mode=self._app_mode, app_mode=self._app_mode,
event=event event=event,
) )
self._q.put(message) self._q.put(message)
if isinstance(event, QueueStopEvent if isinstance(
| QueueErrorEvent event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
| QueueMessageEndEvent ):
| QueueAdvancedChatMessageEndEvent):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()

@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
""" """
Workflow App Config Entity. Workflow App Config Entity.
""" """
pass pass
@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
app_id=app_model.id, app_id=app_model.id,
app_mode=app_mode, app_mode=app_mode,
workflow_id=workflow.id, workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
config=features_dict variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
), additional_features=cls.convert_features(features_dict, app_mode),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
) )
return app_config return app_config
@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config, config=config, is_vision=False
is_vision=False
) )
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# moderation validation # moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
config=config,
only_structure_validate=only_structure_validate
) )
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)

@ -12,7 +12,7 @@ from pydantic import ValidationError
import contexts import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.app_runner import WorkflowAppRunner
@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator): class WorkflowAppGenerator(BaseAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
@ -46,14 +47,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
@overload @overload
def generate( def generate(
self, app_model: App, self,
app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[False] = False, stream: Literal[False] = False,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
) -> dict: ... ) -> dict: ...
@overload @overload
@ -76,7 +78,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
): ):
""" """
Generate App response. Generate App response.
@ -90,26 +92,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param call_depth: call depth :param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
""" """
inputs = args['inputs'] inputs = args["inputs"]
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg( file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
files,
file_extra_config,
user
)
else: else:
file_objs = [] file_objs = []
# convert to app config # convert to app config
app_config = WorkflowAppConfigManager.get_app_config( app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_model=app_model,
workflow=workflow
)
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id user_id = user.id if isinstance(user, Account) else user.session_id
@ -125,7 +120,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,
call_depth=call_depth, call_depth=call_depth,
trace_manager=trace_manager trace_manager=trace_manager,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -136,11 +131,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
invoke_from=invoke_from, invoke_from=invoke_from,
stream=stream, stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id workflow_thread_pool_id=workflow_thread_pool_id,
) )
def _generate( def _generate(
self, *, self,
*,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
@ -165,17 +161,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode app_mode=app_model.mode,
) )
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(
'flask_app': current_app._get_current_object(), # type: ignore target=self._generate_worker,
'application_generate_entity': application_generate_entity, kwargs={
'queue_manager': queue_manager, "flask_app": current_app._get_current_object(), # type: ignore
'context': contextvars.copy_context(), "application_generate_entity": application_generate_entity,
'workflow_thread_pool_id': workflow_thread_pool_id "queue_manager": queue_manager,
}) "context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start() worker_thread.start()
@ -188,10 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream, stream=stream,
) )
return WorkflowAppGenerateResponseConverter.convert( return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
response=response,
invoke_from=invoke_from
)
def single_iteration_generate(self, app_model: App, def single_iteration_generate(self, app_model: App,
workflow: Workflow, workflow: Workflow,
@ -210,16 +206,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param stream: is stream :param stream: is stream
""" """
if not node_id: if not node_id:
raise ValueError('node_id is required') raise ValueError("node_id is required")
if args.get('inputs') is None: if args.get("inputs") is None:
raise ValueError('inputs is required') raise ValueError("inputs is required")
# convert to app config # convert to app config
app_config = WorkflowAppConfigManager.get_app_config( app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_model=app_model,
workflow=workflow
)
# init application generate entity # init application generate entity
application_generate_entity = WorkflowAppGenerateEntity( application_generate_entity = WorkflowAppGenerateEntity(
@ -230,13 +223,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
extras={ extras={"auto_generate_conversation_name": False},
"auto_generate_conversation_name": False
},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, node_id=node_id, inputs=args["inputs"]
inputs=args['inputs'] ),
)
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -246,14 +236,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
stream=stream stream=stream,
) )
def _generate_worker(self, flask_app: Flask, def _generate_worker(
application_generate_entity: WorkflowAppGenerateEntity, self,
queue_manager: AppQueueManager, flask_app: Flask,
context: contextvars.Context, application_generate_entity: WorkflowAppGenerateEntity,
workflow_thread_pool_id: Optional[str] = None) -> None: queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -270,22 +263,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner( runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id workflow_thread_pool_id=workflow_thread_pool_id,
) )
runner.run() runner.run()
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'), InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER
) )
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
@ -294,14 +286,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
finally: finally:
db.session.close() db.session.close()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, def _handle_response(
workflow: Workflow, self,
queue_manager: AppQueueManager, application_generate_entity: WorkflowAppGenerateEntity,
user: Union[Account, EndUser], workflow: Workflow,
stream: bool = False) -> Union[ queue_manager: AppQueueManager,
WorkflowAppBlockingResponse, user: Union[Account, EndUser],
Generator[WorkflowAppStreamResponse, None, None] stream: bool = False,
]: ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -317,14 +309,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow, workflow=workflow,
queue_manager=queue_manager, queue_manager=queue_manager,
user=user, user=user,
stream=stream stream=stream,
) )
try: try:
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(e)
raise e raise e

@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager): class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
user_id: str,
invoke_from: InvokeFrom,
app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode self._app_mode = app_mode
@ -27,20 +24,19 @@ class WorkflowAppQueueManager(AppQueueManager):
:param pub_from: :param pub_from:
:return: :return:
""" """
message = WorkflowQueueMessage( message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
task_id=self._task_id,
app_mode=self._app_mode,
event=event
)
self._q.put(message) self._q.put(message)
if isinstance(event, QueueStopEvent if isinstance(
| QueueErrorEvent event,
| QueueMessageEndEvent QueueStopEvent
| QueueWorkflowSucceededEvent | QueueErrorEvent
| QueueWorkflowFailedEvent): | QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()

@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
""" """
def __init__( def __init__(
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None workflow_thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
""" """
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = cast(WorkflowAppConfig, app_config) app_config = cast(WorkflowAppConfig, app_config)
user_id = None user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user: if end_user:
user_id = end_user.session_id user_id = end_user.session_id
@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError('App not found') raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow: if not workflow:
raise ValueError('Workflow not initialized') raise ValueError("Workflow not initialized")
db.session.close() db.session.close()
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested # if only single iteration run is requested
@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs user_inputs=self.application_generate_entity.single_iteration_run.inputs,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files files = self.application_generate_entity.files
@ -114,18 +113,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER else UserFrom.END_USER
), ),
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth, call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool, variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id thread_pool_id=self.workflow_thread_pool_id,
) )
generator = workflow_entry.run( generator = workflow_entry.run(callbacks=workflow_callbacks)
callbacks=workflow_callbacks
)
for event in generator: for event in generator:
self._handle_event(workflow_entry, event) self._handle_event(workflow_entry, event)

@ -46,12 +46,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id, "workflow_run_id": chunk.workflow_run_id,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):
@ -74,12 +74,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse): if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping' yield "ping"
continue continue
response_chunk = { response_chunk = {
'event': sub_stream_response.event.value, "event": sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id, "workflow_run_id": chunk.workflow_run_id,
} }
if isinstance(sub_stream_response, ErrorStreamResponse): if isinstance(sub_stream_response, ErrorStreamResponse):

@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
""" """
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_workflow: Workflow _workflow: Workflow
_user: Union[Account, EndUser] _user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity _application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, def __init__(
workflow: Workflow, self,
queue_manager: AppQueueManager, application_generate_entity: WorkflowAppGenerateEntity,
user: Union[Account, EndUser], workflow: Workflow,
stream: bool) -> None: queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
""" """
Initialize GenerateTaskPipeline. Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow self._workflow = workflow
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id SystemVariableKey.USER_ID: user_id,
} }
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.refresh(self._user) db.session.refresh(self._user)
db.session.close() db.session.close()
generator = self._wrapper_process_stream_response( generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream: if self._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
-> WorkflowAppBlockingResponse:
""" """
To blocking response. To blocking response.
:return: :return:
@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
total_tokens=stream_response.data.total_tokens, total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps, total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at), created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at) finished_at=int(stream_response.data.finished_at),
) ),
) )
return response return response
else: else:
continue continue
raise Exception('Queue listening stopped unexpectedly.') raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ def _to_stream_response(
-> Generator[WorkflowAppStreamResponse, None, None]: self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
""" """
To stream response. To stream response.
:return: :return:
@ -158,34 +160,34 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, WorkflowStartStreamResponse): if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse( yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
workflow_run_id=workflow_run_id,
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher, task_id: str):
if not publisher: if not publisher:
return None return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio() audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ def _wrapper_process_stream_response(
Generator[StreamResponse, None, None]: self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None tts_publisher = None
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ if (
'text_to_speech'].get('autoPlay') == 'enabled': features_dict.get("text_to_speech")
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response: if audio_response:
yield audio_response yield audio_response
else: else:
@ -197,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
try: try:
if not tts_publisher: if not tts_publisher:
break break
audio_trunk = tts_publisher.checkAndGetAudio() audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None: if audio_trunk is None:
# release cpu # release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
break break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response( def _process_stream_response(
self, self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None, tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
""" """
Process stream response. Process stream response.
@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# init workflow run # init workflow run
workflow_run = self._handle_workflow_run_start() workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response( yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
elif isinstance(event, QueueNodeStartedEvent): elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start( workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
workflow_run=workflow_run,
event=event
)
response = self._workflow_node_start_to_stream_response( response = self._workflow_node_start_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
response = self._workflow_node_finish_to_stream_response( response = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
response = self._workflow_node_finish_to_stream_response( response = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution workflow_node_execution=workflow_node_execution,
) )
if response: if response:
yield response yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response( yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response( yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationStartEvent): elif isinstance(event, QueueIterationStartEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response( yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationNextEvent): elif isinstance(event, QueueIterationNextEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response( yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueIterationCompletedEvent): elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response( yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
workflow_run=workflow_run,
event=event
) )
elif isinstance(event, QueueWorkflowSucceededEvent): elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.') raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success( workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run, workflow_run=workflow_run,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, outputs=json.dumps(event.outputs)
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
conversation_id=None, conversation_id=None,
trace_manager=trace_manager, trace_manager=trace_manager,
) )
@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception("Workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.') raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed( workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run, workflow_run=workflow_run,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None, conversation_id=None,
trace_manager=trace_manager, trace_manager=trace_manager,
@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response( yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_run=workflow_run
) )
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
@ -387,14 +376,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
tts_publisher.publish(message=queue_message) tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text) yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
else: else:
continue continue
if tts_publisher: if tts_publisher:
tts_publisher.publish(None) tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
""" """
Save workflow app log. Save workflow app log.
@ -417,14 +407,16 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
workflow_app_log.created_by = self._user.id workflow_app_log.created_by = self._user.id
db.session.add(workflow_app_log) db.session.add(workflow_app_log)
db.session.commit() db.session.commit()
db.session.close() db.session.close()
def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
) -> TextChunkStreamResponse:
""" """
Handle completed event. Handle completed event.
:param text: text :param text: text
@ -432,7 +424,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
""" """
response = TextChunkStreamResponse( response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text) data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
) )
return response return response

@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner):
""" """
Init graph Init graph
""" """
if 'nodes' not in graph_config or 'edges' not in graph_config: if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError('nodes or edges not found in workflow graph') raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get('nodes'), list): if not isinstance(graph_config.get("nodes"), list):
raise ValueError('nodes in workflow graph must be a list') raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get('edges'), list): if not isinstance(graph_config.get("edges"), list):
raise ValueError('edges in workflow graph must be a list') raise ValueError("edges in workflow graph must be a list")
# init graph # init graph
graph = Graph.init( graph = Graph.init(graph_config=graph_config)
graph_config=graph_config
)
if not graph: if not graph:
raise ValueError('graph not found in workflow') raise ValueError("graph not found in workflow")
return graph return graph
def _get_graph_and_variable_pool_of_single_iteration( def _get_graph_and_variable_pool_of_single_iteration(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict, user_inputs: dict,
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
Get variable pool of single iteration Get variable pool of single iteration
""" """
# fetch workflow graph # fetch workflow graph
graph_config = workflow.graph_dict graph_config = workflow.graph_dict
if not graph_config: if not graph_config:
raise ValueError('workflow graph not found') raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config) graph_config = cast(dict[str, Any], graph_config)
if 'nodes' not in graph_config or 'edges' not in graph_config: if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError('nodes or edges not found in workflow graph') raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get('nodes'), list): if not isinstance(graph_config.get("nodes"), list):
raise ValueError('nodes in workflow graph must be a list') raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get('edges'), list): if not isinstance(graph_config.get("edges"), list):
raise ValueError('edges in workflow graph must be a list') raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration # filter nodes only in iteration
node_configs = [ node_configs = [
node for node in graph_config.get('nodes', []) node
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
] ]
graph_config['nodes'] = node_configs graph_config["nodes"] = node_configs
node_ids = [node.get('id') for node in node_configs] node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration # filter edges only in iteration
edge_configs = [ edge_configs = [
edge for edge in graph_config.get('edges', []) edge
if (edge.get('source') is None or edge.get('source') in node_ids) for edge in graph_config.get("edges", [])
and (edge.get('target') is None or edge.get('target') in node_ids) if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
] ]
graph_config['edges'] = edge_configs graph_config["edges"] = edge_configs
# init graph # init graph
graph = Graph.init( graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
graph_config=graph_config,
root_node_id=node_id
)
if not graph: if not graph:
raise ValueError('graph not found in workflow') raise ValueError("graph not found in workflow")
# fetch node config from node id # fetch node config from node id
iteration_node_config = None iteration_node_config = None
for node in node_configs: for node in node_configs:
if node.get('id') == node_id: if node.get("id") == node_id:
iteration_node_config = node iteration_node_config = node
break break
if not iteration_node_config: if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph') raise ValueError("iteration node id not found in workflow graph")
# Get node class # Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type) node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls) node_cls = cast(type[BaseNode], node_cls)
@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner):
try: try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, graph_config=workflow.graph_dict, config=iteration_node_config
config=iteration_node_config
) )
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner):
variable_pool=variable_pool, variable_pool=variable_pool,
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
node_type=node_type, node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {})) node_data=IterationNodeData(**iteration_node_config.get("data", {})),
) )
return graph, variable_pool return graph, variable_pool
@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner):
""" """
if isinstance(event, GraphRunStartedEvent): if isinstance(event, GraphRunStartedEvent):
self._publish_event( self._publish_event(
QueueWorkflowStartedEvent( QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
)
) )
elif isinstance(event, GraphRunSucceededEvent): elif isinstance(event, GraphRunSucceededEvent):
self._publish_event( self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
QueueWorkflowSucceededEvent(outputs=event.outputs)
)
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self._publish_event( self._publish_event(QueueWorkflowFailedEvent(error=event.error))
QueueWorkflowFailedEvent(error=event.error)
)
elif isinstance(event, NodeRunStartedEvent): elif isinstance(event, NodeRunStartedEvent):
self._publish_event( self._publish_event(
QueueNodeStartedEvent( QueueNodeStartedEvent(
@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner):
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index, node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id, predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunSucceededEvent): elif isinstance(event, NodeRunSucceededEvent):
@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
in_iteration_id=event.in_iteration_id else {},
in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {}, if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error", else "Unknown error",
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner):
QueueTextChunkEvent( QueueTextChunkEvent(
text=event.chunk_content, text=event.chunk_content,
from_variable_selector=event.from_variable_selector, from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, NodeRunRetrieverResourceEvent): elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event( self._publish_event(
QueueRetrieverResourcesEvent( QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources, retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
in_iteration_id=event.in_iteration_id
) )
) )
elif isinstance(event, ParallelBranchRunStartedEvent): elif isinstance(event, ParallelBranchRunStartedEvent):
@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, ParallelBranchRunSucceededEvent): elif isinstance(event, ParallelBranchRunSucceededEvent):
@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner):
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id in_iteration_id=event.in_iteration_id,
) )
) )
elif isinstance(event, ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunFailedEvent):
@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
error=event.error error=event.error,
) )
) )
elif isinstance(event, IterationRunStartedEvent): elif isinstance(event, IterationRunStartedEvent):
@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner):
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs, inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id, predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata metadata=event.metadata,
) )
) )
elif isinstance(event, IterationRunNextEvent): elif isinstance(event, IterationRunNextEvent):
@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner):
outputs=event.outputs, outputs=event.outputs,
metadata=event.metadata, metadata=event.metadata,
steps=event.steps, steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None error=event.error if isinstance(event, IterationRunFailedEvent) else None,
) )
) )
@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner):
# return workflow # return workflow
return workflow return workflow
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish( self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
event,
PublishFrom.APPLICATION_MANAGER
)

@ -30,169 +30,150 @@ _TEXT_COLOR_MAPPING = {
class WorkflowLoggingCallback(WorkflowCallback): class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None: def __init__(self) -> None:
self.current_node_id = None self.current_node_id = None
def on_event( def on_event(self, event: GraphEngineEvent) -> None:
self,
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent): if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color='pink') self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent): elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color='green') self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent): elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started( self.on_workflow_node_execute_started(event=event)
event=event
)
elif isinstance(event, NodeRunSucceededEvent): elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded( self.on_workflow_node_execute_succeeded(event=event)
event=event
)
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed( self.on_workflow_node_execute_failed(event=event)
event=event
)
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk( self.on_node_text_chunk(event=event)
event=event
)
elif isinstance(event, ParallelBranchRunStartedEvent): elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started( self.on_workflow_parallel_started(event=event)
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed( self.on_workflow_parallel_completed(event=event)
event=event
)
elif isinstance(event, IterationRunStartedEvent): elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started( self.on_workflow_iteration_started(event=event)
event=event
)
elif isinstance(event, IterationRunNextEvent): elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next( self.on_workflow_iteration_next(event=event)
event=event
)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed( self.on_workflow_iteration_completed(event=event)
event=event
)
else: else:
self.print_text(f"\n[{event.__class__.__name__}]", color='blue') self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
    def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
        """
        Workflow node execute started.

        Logs the node's ID, title and type in yellow when a node begins running.

        :param event: the NodeRunStartedEvent emitted by the workflow engine
        """
        self.print_text("\n[NodeRunStartedEvent]", color="yellow")
        self.print_text(f"Node ID: {event.node_id}", color="yellow")
        self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
        self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded( def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
self,
event: NodeRunSucceededEvent
) -> None:
""" """
Workflow node execute succeeded Workflow node execute succeeded
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
self.print_text("\n[NodeRunSucceededEvent]", color='green') self.print_text("\n[NodeRunSucceededEvent]", color="green")
self.print_text(f"Node ID: {event.node_id}", color='green') self.print_text(f"Node ID: {event.node_id}", color="green")
self.print_text(f"Node Title: {event.node_data.title}", color='green') self.print_text(f"Node Title: {event.node_data.title}", color="green")
self.print_text(f"Type: {event.node_type.value}", color='green') self.print_text(f"Type: {event.node_type.value}", color="green")
if route_node_state.node_run_result: if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result node_run_result = route_node_state.node_run_result
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='green')
self.print_text( self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='green') color="green",
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", )
color='green') self.print_text(
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="green",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="green",
)
self.print_text( self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color='green') color="green",
)
def on_workflow_node_execute_failed( def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
self,
event: NodeRunFailedEvent
) -> None:
""" """
Workflow node execute failed Workflow node execute failed
""" """
route_node_state = event.route_node_state route_node_state = event.route_node_state
self.print_text("\n[NodeRunFailedEvent]", color='red') self.print_text("\n[NodeRunFailedEvent]", color="red")
self.print_text(f"Node ID: {event.node_id}", color='red') self.print_text(f"Node ID: {event.node_id}", color="red")
self.print_text(f"Node Title: {event.node_data.title}", color='red') self.print_text(f"Node Title: {event.node_data.title}", color="red")
self.print_text(f"Type: {event.node_type.value}", color='red') self.print_text(f"Type: {event.node_type.value}", color="red")
if route_node_state.node_run_result: if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color='red') self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='red')
self.print_text( self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='red') color="red",
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", )
color='red') self.print_text(
f"Process Data: "
def on_node_text_chunk( f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
self, color="red",
event: NodeRunStreamChunkEvent )
) -> None: self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="red",
)
    def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
        """
        Publish text chunk.

        Prints a header (and, if present, run-result metadata) the first time a
        chunk arrives from a new node, then streams the chunk content inline in
        pink without a trailing newline.

        :param event: the NodeRunStreamChunkEvent carrying the streamed content
        """
        route_node_state = event.route_node_state
        # Only emit the header/metadata when the stream switches to a different node.
        if not self.current_node_id or self.current_node_id != route_node_state.node_id:
            self.current_node_id = route_node_state.node_id
            self.print_text("\n[NodeRunStreamChunkEvent]")
            self.print_text(f"Node ID: {route_node_state.node_id}")

            node_run_result = route_node_state.node_run_result
            if node_run_result:
                self.print_text(
                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
                )

        # end="" so successive chunks join into one continuous line.
        self.print_text(event.chunk_content, color="pink", end="")
    def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
        """
        Publish parallel started.

        Logs the parallel run and branch identifiers in blue; the iteration ID
        is printed only when the branch runs inside an iteration.

        :param event: the ParallelBranchRunStartedEvent emitted by the workflow engine
        """
        self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
        self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
def on_workflow_parallel_completed( def on_workflow_parallel_completed(
self, self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None: ) -> None:
""" """
Publish parallel completed Publish parallel completed
""" """
if isinstance(event, ParallelBranchRunSucceededEvent): if isinstance(event, ParallelBranchRunSucceededEvent):
color = 'blue' color = "blue"
elif isinstance(event, ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunFailedEvent):
color = 'red' color = "red"
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) self.print_text(
"\n[ParallelBranchRunSucceededEvent]"
if isinstance(event, ParallelBranchRunSucceededEvent)
else "\n[ParallelBranchRunFailedEvent]",
color=color,
)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color) self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id: if event.in_iteration_id:
@ -201,43 +182,37 @@ class WorkflowLoggingCallback(WorkflowCallback):
if isinstance(event, ParallelBranchRunFailedEvent): if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color) self.print_text(f"Error: {event.error}", color=color)
    def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
        """
        Publish iteration started.

        Logs the iteration node's ID in blue when an iteration run begins.

        :param event: the IterationRunStartedEvent emitted by the workflow engine
        """
        self.print_text("\n[IterationRunStartedEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
    def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
        """
        Publish iteration next.

        Logs the iteration node's ID and the zero-based index of the round that
        is about to run, in blue.

        :param event: the IterationRunNextEvent emitted by the workflow engine
        """
        self.print_text("\n[IterationRunNextEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
        self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed( def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
self,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
""" """
Publish iteration completed Publish iteration completed
""" """
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') self.print_text(
self.print_text(f"Node ID: {event.iteration_id}", color='blue') "\n[IterationRunSucceededEvent]"
if isinstance(event, IterationRunSucceededEvent)
else "\n[IterationRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def print_text( def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
self, text: str, color: Optional[str] = None, end: str = "\n"
) -> None:
"""Print text with highlighting and no end characters.""" """Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text text_to_print = self._get_colored_text(text, color) if color else text
print(f'{text_to_print}', end=end) print(f"{text_to_print}", end=end)
def _get_colored_text(self, text: str, color: str) -> str: def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text.""" """Get colored text."""

@ -15,13 +15,14 @@ class InvokeFrom(Enum):
""" """
Invoke From. Invoke From.
""" """
    # String identifiers for each invocation source (these exact values are
    # what value_of() matches against).
    SERVICE_API = "service-api"
    WEB_APP = "web-app"
    EXPLORE = "explore"
    DEBUGGER = "debugger"
@classmethod @classmethod
def value_of(cls, value: str) -> 'InvokeFrom': def value_of(cls, value: str) -> "InvokeFrom":
""" """
Get value of given mode. Get value of given mode.
@ -31,7 +32,7 @@ class InvokeFrom(Enum):
for mode in cls: for mode in cls:
if mode.value == value: if mode.value == value:
return mode return mode
raise ValueError(f'invalid invoke from value {value}') raise ValueError(f"invalid invoke from value {value}")
def to_source(self) -> str: def to_source(self) -> str:
""" """
@ -40,21 +41,22 @@ class InvokeFrom(Enum):
:return: source :return: source
""" """
if self == InvokeFrom.WEB_APP: if self == InvokeFrom.WEB_APP:
return 'web_app' return "web_app"
elif self == InvokeFrom.DEBUGGER: elif self == InvokeFrom.DEBUGGER:
return 'dev' return "dev"
elif self == InvokeFrom.EXPLORE: elif self == InvokeFrom.EXPLORE:
return 'explore_app' return "explore_app"
elif self == InvokeFrom.SERVICE_API: elif self == InvokeFrom.SERVICE_API:
return 'api' return "api"
return 'dev' return "dev"
class ModelConfigWithCredentialsEntity(BaseModel): class ModelConfigWithCredentialsEntity(BaseModel):
""" """
Model Config With Credentials Entity. Model Config With Credentials Entity.
""" """
provider: str provider: str
model: str model: str
model_schema: AIModelEntity model_schema: AIModelEntity
@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
""" """
App Generate Entity. App Generate Entity.
""" """
task_id: str task_id: str
# app config # app config
@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
""" """
Chat Application Generate Entity. Chat Application Generate Entity.
""" """
# app config # app config
app_config: EasyUIBasedAppConfig app_config: EasyUIBasedAppConfig
model_conf: ModelConfigWithCredentialsEntity model_conf: ModelConfigWithCredentialsEntity
@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
""" """
Chat Application Generate Entity. Chat Application Generate Entity.
""" """
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
""" """
Completion Application Generate Entity. Completion Application Generate Entity.
""" """
pass pass
@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
""" """
Agent Chat Application Generate Entity. Agent Chat Application Generate Entity.
""" """
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
""" """
Advanced Chat Application Generate Entity. Advanced Chat Application Generate Entity.
""" """
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
""" """
Single Iteration Run Entity. Single Iteration Run Entity.
""" """
node_id: str node_id: str
inputs: dict inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None single_iteration_run: Optional[SingleIterationRunEntity] = None
class WorkflowAppGenerateEntity(AppGenerateEntity): class WorkflowAppGenerateEntity(AppGenerateEntity):
""" """
Workflow Application Generate Entity. Workflow Application Generate Entity.
""" """
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
""" """
Single Iteration Run Entity. Single Iteration Run Entity.
""" """
node_id: str node_id: str
inputs: dict inputs: dict

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save