diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d97074e8b9..e1415e1c7a 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) from . import index -from .app import annotation, app, audio, completion, conversation, file, message, workflow +from .app import annotation, app, audio, completion, conversation, file, message, workflow, workflow_run from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 38a65b7a90..2b8e587c6a 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -108,6 +108,7 @@ class ChatApi(Resource): parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") + parser.add_argument("is_async", type=bool, required=False, default=False, location="json") args = parser.parse_args() diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 95e538f4c7..9779244666 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -11,10 +11,11 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import message_file_fields -from fields.message_fields import agent_thought_fields, feedback_fields +from fields.message_fields import agent_thought_fields, feedback_fields, message_fields from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value -from models.model import App, AppMode, EndUser +from models import db +from models.model import App, AppMode, EndUser, Message from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -116,6 +117,21 @@ class MessageSuggestedApi(Resource): return {"result": "success", "data": questions} +class MessageApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + @marshal_with(message_fields) + def get(self, app_model: App, end_user: EndUser, message_id): + message_id = str(message_id) + + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() + + if not message: + raise NotFound("Message Not Exists.") + + return message + + api.add_resource(MessageListApi, "/messages") api.add_resource(MessageFeedbackApi, "/messages//feedbacks") api.add_resource(MessageSuggestedApi, "/messages//suggested") +api.add_resource(MessageApi, "/apps/messages/") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 8b10a028f3..0e3897e0a8 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -80,6 +80,8 @@ class WorkflowRunApi(Resource): parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("is_async", type=bool, required=False, default=False, location="json") + args = parser.parse_args() streaming = args.get("response_mode") == "streaming" diff --git a/api/controllers/service_api/app/workflow_run.py b/api/controllers/service_api/app/workflow_run.py new file mode 100644 index 0000000000..746234d084 --- /dev/null +++ b/api/controllers/service_api/app/workflow_run.py @@ -0,0 +1,27 @@ +from flask_restful import Resource, marshal_with + +from controllers.service_api import api +from controllers.service_api.wraps import validate_app_token +from fields.workflow_run_fields import ( + workflow_run_node_execution_fields, +) +from models import App +from services.workflow_run_service import WorkflowRunService + + +class WorkflowRunDetailApi(Resource): + @validate_app_token + @marshal_with(workflow_run_node_execution_fields) + def get(self, app_model: App, run_id): + """ + Get workflow run detail + """ + run_id = str(run_id) + + workflow_run_service = WorkflowRunService() + workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id) + + return workflow_run + + +api.add_resource(WorkflowRunDetailApi, "/apps/workflow-runs/", endpoint="workflow_run_detail") diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index ef582d28e0..1b7998def1 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -100,7 +100,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): query = query.replace("\x00", "") inputs = args["inputs"] - extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False), + "is_async": args.get("is_async", False)} # get conversation conversation = None diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 3bf6c330db..b296a64fa3 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,16 +1,21 @@ +import contextvars import json import logging +import threading import time from collections.abc import Generator, Mapping from threading import Thread from typing import Any, Optional, Union +from flask import Flask, current_app + from sqlalchemy import select from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -46,6 +51,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.app.entities.queue_task_bridge import advance_chat_queue_task_map, ForwardQueueMessage from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, ChatbotAppStreamResponse, @@ -54,7 +60,7 @@ from core.app.entities.task_entities import ( MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, - WorkflowTaskState, + WorkflowTaskState, MessageStreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage @@ -163,7 +169,21 @@ class AdvancedChatAppGenerateTaskPipeline: Process blocking response. :return: """ + is_async = self._application_generate_entity.extras.get("is_async", False) + for stream_response in generator: + if is_async: + return ChatbotAppBlockingResponse( + task_id=self._application_generate_entity.workflow_run_id, + data=ChatbotAppBlockingResponse.Data( + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, + answer=self._application_generate_entity.workflow_run_id, + created_at=int(self._message_created_at), + ), + ) if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): @@ -195,7 +215,20 @@ class AdvancedChatAppGenerateTaskPipeline: To stream response. :return: """ + is_async = self._application_generate_entity.extras.get("is_async", False) for stream_response in generator: + if is_async: + yield ChatbotAppStreamResponse( + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=int(self._message_created_at), + stream_response=MessageStreamResponse( + task_id=self._application_generate_entity.task_id, + id="0", + answer=self._application_generate_entity.workflow_run_id, + ), + ) + return yield ChatbotAppStreamResponse( conversation_id=self._conversation_id, message_id=self._message_id, @@ -228,14 +261,7 @@ class AdvancedChatAppGenerateTaskPipeline: tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") ) - for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): - while True: - audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) - if audio_response: - yield audio_response - else: - break - yield response + yield from self._async_process_stream_response(tts_publisher) start_listener_time = time.time() # timeout @@ -260,6 +286,67 @@ class AdvancedChatAppGenerateTaskPipeline: if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + def _async_process_stream_response(self, publisher): + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=self._application_generate_entity.task_id, + user_id=self._application_generate_entity.user_id, + invoke_from=self._application_generate_entity.invoke_from, + conversation_id=self._conversation_id, + app_mode=self._conversation_mode, + message_id=self._message_id, + ) + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'queue_manager': queue_manager, + 'context': contextvars.copy_context(), + 'publisher': publisher + }) + + worker_thread.start() + + yield from self._consumer_worker(queue_manager) + + def _consumer_worker(self, queue_manager: AppQueueManager) -> Generator[StreamResponse, None, None]: + for message in queue_manager.listen(): + event = message.event + if isinstance(event, ForwardQueueMessage): + yield event.response + + def _generate_worker(self, flask_app: Flask, + queue_manager: AppQueueManager, + context: contextvars.Context, publisher) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param queue_manager: queue manager + :return: + """ + for var, val in context.items(): + var.set(val) + with flask_app.app_context(): + response_generator = self._sync_process_stream_response( + publisher, + ) + for generator in response_generator: + if generator is None: + continue + message = ForwardQueueMessage(event=advance_chat_queue_task_map[generator.event], response=generator) + queue_manager.publish(message, PublishFrom.TASK_PIPELINE) + + def _sync_process_stream_response(self, publisher): + + for response in self._process_stream_response(publisher, + trace_manager=self._application_generate_entity.trace_manager): + + while True: + audio_response = self._listen_audio_msg(publisher, task_id=self._application_generate_entity.task_id) + if audio_response: + yield audio_response + else: + break + yield response + def _process_stream_response( self, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 08986b16f0..8dee392fc1 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -127,6 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator): call_depth=call_depth, trace_manager=trace_manager, workflow_run_id=workflow_run_id, + extras={"is_async": args.get("is_async", False)}, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 1f998edb6a..e9744a2cc1 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,13 +1,17 @@ +import contextvars import logging +import threading import time from collections.abc import Generator from typing import Optional, Union +from flask import Flask, current_app from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk -from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, @@ -39,6 +43,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.app.entities.queue_task_bridge import workflow_queue_task_map, ForwardQueueMessage from core.app.entities.task_entities import ( ErrorStreamResponse, MessageAudioEndStreamResponse, @@ -49,7 +54,7 @@ from core.app.entities.task_entities import ( WorkflowAppStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, + WorkflowTaskState, MessageStreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage @@ -133,7 +138,19 @@ class WorkflowAppGenerateTaskPipeline: To blocking response. :return: """ + is_async = self._application_generate_entity.extras.get("is_async", False) for stream_response in generator: + if is_async: + return WorkflowAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + workflow_run_id=self._application_generate_entity.workflow_run_id, + data=WorkflowAppBlockingResponse.Data( + id=self._application_generate_entity.app_config.app_id, + workflow_id=self._workflow_id, + status='processing', + created_at=int(time.time()), + ), + ) if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, WorkflowFinishStreamResponse): @@ -168,9 +185,21 @@ class WorkflowAppGenerateTaskPipeline: :return: """ workflow_run_id = None + is_async = self._application_generate_entity.extras.get("is_async", False) + for stream_response in generator: if isinstance(stream_response, WorkflowStartStreamResponse): workflow_run_id = stream_response.workflow_run_id + if is_async: + yield WorkflowAppStreamResponse( + workflow_run_id=workflow_run_id, + stream_response=MessageStreamResponse( + task_id=self._application_generate_entity.task_id, + id="0", + answer=self._application_generate_entity.workflow_run_id, + ), + ) + return yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) @@ -199,14 +228,7 @@ class WorkflowAppGenerateTaskPipeline: tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") ) - for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): - while True: - audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) - if audio_response: - yield audio_response - else: - break - yield response + yield from self._async_process_stream_response(tts_publisher) start_listener_time = time.time() while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: @@ -229,6 +251,64 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) + def _consumer_worker(self, queue_manager: AppQueueManager) -> Generator[StreamResponse, None, None]: + for message in queue_manager.listen(): + event = message.event + if isinstance(event, ForwardQueueMessage): + yield event.response + + def _generate_worker(self, flask_app: Flask, + queue_manager: AppQueueManager, + context: contextvars.Context, publisher) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param queue_manager: queue manager + :return: + """ + for var, val in context.items(): + var.set(val) + with flask_app.app_context(): + response_generator = self._sync_process_stream_response( + publisher, + ) + for generator in response_generator: + if generator is None: + continue + message = ForwardQueueMessage(event=workflow_queue_task_map[generator.event], response=generator) + queue_manager.publish(message, PublishFrom.TASK_PIPELINE) + + def _async_process_stream_response(self, publisher): + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=self._application_generate_entity.task_id, + user_id=self._application_generate_entity.user_id, + invoke_from=self._application_generate_entity.invoke_from, + app_mode=self._application_generate_entity.app_config.app_mode, + ) + + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'queue_manager': queue_manager, + 'context': contextvars.copy_context(), + 'publisher': publisher + }) + + worker_thread.start() + + yield from self._consumer_worker(queue_manager) + + def _sync_process_stream_response(self, publisher): + for response in self._process_stream_response(publisher, + trace_manager=self._application_generate_entity.trace_manager): + while True: + audio_response = self._listen_audio_msg(publisher, task_id=self._application_generate_entity.task_id) + if audio_response: + yield audio_response + else: + break + yield response + def _process_stream_response( self, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, diff --git a/api/core/app/entities/queue_task_bridge.py b/api/core/app/entities/queue_task_bridge.py new file mode 100644 index 0000000000..2e68669ddf --- /dev/null +++ b/api/core/app/entities/queue_task_bridge.py @@ -0,0 +1,69 @@ +from core.app.entities.queue_entities import QueueEvent, AppQueueEvent +from core.app.entities.task_entities import StreamEvent, StreamResponse + +workflow_queue_task_map = { + StreamEvent.PING: QueueEvent.PING, + StreamEvent.ERROR: QueueEvent.ERROR, + StreamEvent.MESSAGE: QueueEvent.TEXT_CHUNK, + StreamEvent.MESSAGE_END: QueueEvent.MESSAGE_END, + StreamEvent.TTS_MESSAGE: QueueEvent.TEXT_CHUNK, + StreamEvent.TTS_MESSAGE_END: QueueEvent.MESSAGE_END, + StreamEvent.MESSAGE_FILE: QueueEvent.MESSAGE_FILE, + StreamEvent.MESSAGE_REPLACE: QueueEvent.MESSAGE_REPLACE, + StreamEvent.AGENT_THOUGHT: QueueEvent.AGENT_THOUGHT, + StreamEvent.AGENT_MESSAGE: QueueEvent.AGENT_MESSAGE, + StreamEvent.WORKFLOW_STARTED: QueueEvent.WORKFLOW_STARTED, + StreamEvent.WORKFLOW_FINISHED: QueueEvent.WORKFLOW_SUCCEEDED, + StreamEvent.NODE_STARTED: QueueEvent.NODE_STARTED, + StreamEvent.NODE_FINISHED: QueueEvent.NODE_SUCCEEDED, + StreamEvent.NODE_RETRY: QueueEvent.RETRY, + StreamEvent.PARALLEL_BRANCH_STARTED: QueueEvent.PARALLEL_BRANCH_RUN_STARTED, + StreamEvent.PARALLEL_BRANCH_FINISHED: QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED, + StreamEvent.ITERATION_STARTED: QueueEvent.ITERATION_START, + StreamEvent.ITERATION_NEXT: QueueEvent.ITERATION_NEXT, + StreamEvent.ITERATION_COMPLETED: QueueEvent.ITERATION_COMPLETED, + StreamEvent.LOOP_STARTED: QueueEvent.LOOP_START, + StreamEvent.LOOP_NEXT: QueueEvent.LOOP_NEXT, + StreamEvent.LOOP_COMPLETED: QueueEvent.LOOP_COMPLETED, + StreamEvent.TEXT_CHUNK: QueueEvent.TEXT_CHUNK, + StreamEvent.TEXT_REPLACE: QueueEvent.MESSAGE_REPLACE, + StreamEvent.AGENT_LOG: QueueEvent.AGENT_LOG, + +} + +advance_chat_queue_task_map = { + StreamEvent.PING: QueueEvent.PING, + StreamEvent.ERROR: QueueEvent.ERROR, + StreamEvent.MESSAGE: QueueEvent.TEXT_CHUNK, + StreamEvent.MESSAGE_END: QueueEvent.ADVANCED_CHAT_MESSAGE_END, + StreamEvent.TTS_MESSAGE: QueueEvent.TEXT_CHUNK, + StreamEvent.TTS_MESSAGE_END: QueueEvent.MESSAGE_END, + StreamEvent.MESSAGE_FILE: QueueEvent.MESSAGE_FILE, + StreamEvent.MESSAGE_REPLACE: QueueEvent.MESSAGE_REPLACE, + StreamEvent.AGENT_THOUGHT: QueueEvent.AGENT_THOUGHT, + StreamEvent.AGENT_MESSAGE: QueueEvent.AGENT_MESSAGE, + StreamEvent.WORKFLOW_STARTED: QueueEvent.WORKFLOW_STARTED, + StreamEvent.WORKFLOW_FINISHED: QueueEvent.WORKFLOW_SUCCEEDED, + StreamEvent.NODE_STARTED: QueueEvent.NODE_STARTED, + StreamEvent.NODE_FINISHED: QueueEvent.NODE_SUCCEEDED, + StreamEvent.NODE_RETRY: QueueEvent.RETRY, + StreamEvent.PARALLEL_BRANCH_STARTED: QueueEvent.PARALLEL_BRANCH_RUN_STARTED, + StreamEvent.PARALLEL_BRANCH_FINISHED: QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED, + StreamEvent.ITERATION_STARTED: QueueEvent.ITERATION_START, + StreamEvent.ITERATION_NEXT: QueueEvent.ITERATION_NEXT, + StreamEvent.ITERATION_COMPLETED: QueueEvent.ITERATION_COMPLETED, + StreamEvent.LOOP_STARTED: QueueEvent.LOOP_START, + StreamEvent.LOOP_NEXT: QueueEvent.LOOP_NEXT, + StreamEvent.LOOP_COMPLETED: QueueEvent.LOOP_COMPLETED, + StreamEvent.TEXT_CHUNK: QueueEvent.TEXT_CHUNK, + StreamEvent.TEXT_REPLACE: QueueEvent.MESSAGE_REPLACE, + StreamEvent.AGENT_LOG: QueueEvent.AGENT_LOG, +} + + +class ForwardQueueMessage(AppQueueEvent): + """ + ForwardQueueMessage entity + """ + event: QueueEvent = QueueEvent.PING + response: StreamResponse