From d595f74a3ddece28f09960d3c764130e2aa8890a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 30 May 2025 06:15:27 +0800 Subject: [PATCH] refactor: Refactors message cycle management Signed-off-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 16 ++++++------ .../apps/workflow/generate_task_pipeline.py | 4 --- .../easy_ui_based_generate_task_pipeline.py | 25 ++++++++++++------- ...cle_manage.py => message_cycle_manager.py} | 14 +++++------ 4 files changed, 31 insertions(+), 28 deletions(-) rename api/core/app/task_pipeline/{message_cycle_manage.py => message_cycle_manager.py} (91%) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index ffce11187b..237089df45 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -57,7 +57,7 @@ from core.app.entities.task_entities import ( WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline -from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder @@ -141,7 +141,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) self._task_state = WorkflowTaskState() - self._message_cycle_manager = MessageCycleManage( + self._message_cycle_manager = MessageCycleManager( application_generate_entity=application_generate_entity, task_state=self._task_state ) @@ -162,7 +162,7 @@ class AdvancedChatAppGenerateTaskPipeline: :return: """ # start generate conversation name thread - self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( + self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query ) @@ -605,7 +605,7 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): - self._message_cycle_manager._handle_retriever_resources(event) + self._message_cycle_manager.handle_retriever_resources(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) @@ -614,7 +614,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) session.commit() elif isinstance(event, QueueAnnotationReplyEvent): - self._message_cycle_manager._handle_annotation_reply(event) + self._message_cycle_manager.handle_annotation_reply(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) @@ -637,12 +637,12 @@ class AdvancedChatAppGenerateTaskPipeline: tts_publisher.publish(queue_message) self._task_state.answer += delta_text - yield self._message_cycle_manager._message_to_stream_response( + yield self._message_cycle_manager.message_to_stream_response( answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation - yield self._message_cycle_manager._message_replace_to_stream_response( + yield self._message_cycle_manager.message_replace_to_stream_response( answer=event.text, reason=event.reason ) elif isinstance(event, QueueAdvancedChatMessageEndEvent): @@ -654,7 +654,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) if output_moderation_answer: self._task_state.answer = output_moderation_answer - yield self._message_cycle_manager._message_replace_to_stream_response( + yield self._message_cycle_manager.message_replace_to_stream_response( answer=output_moderation_answer, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index e678774fae..9af471cab7 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -50,7 +50,6 @@ from core.app.entities.task_entities import ( WorkflowAppStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk @@ -130,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline: ) self._application_generate_entity = application_generate_entity - self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict - self._task_state = WorkflowTaskState() self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -543,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: tts_publisher.publish(queue_message) - self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( delta_text, from_variable_selector=event.from_variable_selector ) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 6c768fd86c..6156b2973e 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -43,7 +43,7 @@ from core.app.entities.task_entities import ( StreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline -from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -63,7 +63,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) -class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): +class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ @@ -104,6 +104,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) ) + self._message_cycle_manager = MessageCycleManager( + application_generate_entity=application_generate_entity, + task_state=self._task_state, + ) + self._conversation_name_generate_thread: Optional[Thread] = None def process( @@ -115,7 +120,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ]: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread - self._conversation_name_generate_thread = self._generate_conversation_name( + self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) @@ -277,7 +282,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=output_moderation_answer + ) with Session(db.engine) as session: # Save message @@ -286,9 +293,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_end_resp = self._message_end_to_stream_response() yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): - self._handle_retriever_resources(event) + self._message_cycle_manager.handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): - annotation = self._handle_annotation_reply(event) + annotation = self._message_cycle_manager.handle_annotation_reply(event) if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): @@ -296,7 +303,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if agent_thought_response is not None: yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): - response = self._message_file_to_stream_response(event) + response = self._message_cycle_manager.message_file_to_stream_response(event) if response: yield response elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): @@ -318,7 +325,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response( + yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, ) @@ -328,7 +335,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message_id, ) elif isinstance(event, QueueMessageReplaceEvent): - yield self._message_replace_to_stream_response(answer=event.text) + yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): yield self._ping_stream_response() else: diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manager.py similarity index 91% rename from api/core/app/task_pipeline/message_cycle_manage.py rename to api/core/app/task_pipeline/message_cycle_manager.py index a6d826f08b..8d762a6655 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -30,7 +30,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService -class MessageCycleManage: +class MessageCycleManager: def __init__( self, *, @@ -45,7 +45,7 @@ class MessageCycleManage: self._application_generate_entity = application_generate_entity self._task_state = task_state - def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: + def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation name. :param conversation_id: conversation id @@ -102,7 +102,7 @@ class MessageCycleManage: db.session.commit() db.session.close() - def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: + def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: """ Handle annotation reply. :param event: event @@ -120,7 +120,7 @@ class MessageCycleManage: return None - def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: + def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: """ Handle retriever resources. :param event: event @@ -129,7 +129,7 @@ class MessageCycleManage: if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata["retriever_resources"] = event.retriever_resources - def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: + def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ Message file to stream response. :param event: event @@ -166,7 +166,7 @@ class MessageCycleManage: return None - def _message_to_stream_response( + def message_to_stream_response( self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None ) -> MessageStreamResponse: """ @@ -182,7 +182,7 @@ class MessageCycleManage: from_variable_selector=from_variable_selector, ) - def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: + def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: """ Message replace to stream response. :param answer: answer