diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 09220de6e6..d8bcd84b51 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -164,8 +164,7 @@ class AdvancedChatAppGenerateTaskPipeline: conversation_id=self._conversation_id, query=self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._base_task_pipeline._stream: return self._to_stream_response(generator) else: @@ -186,7 +185,7 @@ class AdvancedChatAppGenerateTaskPipeline: # Retrieve outputs from task state metadata, which is populated earlier final_outputs = {} - if self._task_state.metadata and hasattr(self._task_state.metadata, 'outputs'): + if self._task_state.metadata and hasattr(self._task_state.metadata, "outputs"): final_outputs = self._task_state.metadata.outputs return ChatbotAppBlockingResponse( @@ -244,14 +243,12 @@ class AdvancedChatAppGenerateTaskPipeline: and features_dict["text_to_speech"].get("autoPlay") == "enabled" ): tts_publisher = AppGeneratorTTSPublisher( - tenant_id, features_dict["text_to_speech"].get( - "voice"), features_dict["text_to_speech"].get("language") + tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") ) for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg( - publisher=tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -276,8 +273,7 @@ class AdvancedChatAppGenerateTaskPipeline: start_listener_time = time.time() yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) except Exception: - logger.exception( - f"Failed to listen audio message, task_id: {task_id}") + logger.exception(f"Failed to listen audio message, task_id: {task_id}") break if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) @@ -317,8 +313,7 @@ class AdvancedChatAppGenerateTaskPipeline: self._workflow_run_id = workflow_execution.id_ message = self._get_message(session=session) if not message: - raise ValueError( - f"Message not found: {self._message_id}") + raise ValueError(f"Message not found: {self._message_id}") message.workflow_run_id = workflow_execution.id_ workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, @@ -367,8 +362,7 @@ class AdvancedChatAppGenerateTaskPipeline: # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend( - self._workflow_response_converter.fetch_files_from_node_outputs( - event.outputs or {}) + self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) ) with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( @@ -516,10 +510,8 @@ class AdvancedChatAppGenerateTaskPipeline: task_id=self._application_generate_entity.task_id, workflow_execution=workflow_execution, ) - workflow_outputs_data = workflow_finish_resp.data.outputs.get( - 'outputs', {}) - self._task_state.metadata.outputs = workflow_outputs_data.get( - 'outputs') + workflow_outputs_data = workflow_finish_resp.data.outputs.get("outputs", {}) + self._task_state.metadata.outputs = workflow_outputs_data.get("outputs") yield workflow_finish_resp self._base_task_pipeline._queue_manager.publish( QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE @@ -572,8 +564,7 @@ class AdvancedChatAppGenerateTaskPipeline: task_id=self._application_generate_entity.task_id, workflow_execution=workflow_execution, ) - err_event = QueueErrorEvent(error=ValueError( - f"Run failed: {workflow_execution.error_message}")) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) err = self._base_task_pipeline._handle_error( event=err_event, session=session, message_id=self._message_id ) @@ -599,8 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline: workflow_execution=workflow_execution, ) # Save message - self._save_message( - session=session, graph_runtime_state=graph_runtime_state) + self._save_message(session=session, graph_runtime_state=graph_runtime_state) session.commit() yield workflow_finish_resp @@ -636,8 +626,7 @@ class AdvancedChatAppGenerateTaskPipeline: continue # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk( - delta_text) + should_direct_answer = self._handle_output_moderation_chunk(delta_text) if should_direct_answer: continue @@ -669,8 +658,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) # Save message with Session(db.engine, expire_on_commit=False) as session: - self._save_message( - session=session, graph_runtime_state=graph_runtime_state) + self._save_message(session=session, graph_runtime_state=graph_runtime_state) session.commit() yield self._message_end_to_stream_response() @@ -691,8 +679,7 @@ class AdvancedChatAppGenerateTaskPipeline: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: message = self._get_message(session=session) message.answer = self._task_state.answer - message.provider_response_latency = time.perf_counter() - \ - self._base_task_pipeline._start_at + message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( @@ -757,18 +744,15 @@ class AdvancedChatAppGenerateTaskPipeline: # stop subscribe new token when output moderation should direct output self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() self._base_task_pipeline._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer), PublishFrom.TASK_PIPELINE + QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) self._base_task_pipeline._queue_manager.publish( - QueueStopEvent( - stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: - self._base_task_pipeline._output_moderation_handler.append_new_token( - text) + self._base_task_pipeline._output_moderation_handler.append_new_token(text) return False