From 48be8fb6cc224ab76a15a47c5c6828a8c4f4fc3e Mon Sep 17 00:00:00 2001
From: Novice
Date: Fri, 25 Apr 2025 09:33:41 +0800
Subject: [PATCH] chore: lack of the user message

---
 api/core/memory/base_memory.py                |  18 ++
 api/core/memory/model_context_memory.py       | 205 ++++++++++++++++++
 .../entities/advanced_prompt_entities.py      |   7 +
 .../prompt/utils/extract_thread_messages.py   |   5 +-
 api/core/workflow/nodes/llm/node.py           |  23 +-
 .../nodes/_base/components/memory-config.tsx  |  21 +-
 web/app/components/workflow/types.ts          |   6 +
 7 files changed, 274 insertions(+), 11 deletions(-)
 create mode 100644 api/core/memory/base_memory.py
 create mode 100644 api/core/memory/model_context_memory.py

diff --git a/api/core/memory/base_memory.py b/api/core/memory/base_memory.py
new file mode 100644
index 0000000000..259d6d6a59
--- /dev/null
+++ b/api/core/memory/base_memory.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+
+from core.model_runtime.entities.message_entities import PromptMessage
+
+
+class BaseMemory(ABC):
+    @abstractmethod
+    def get_history_prompt_messages(self) -> Sequence[PromptMessage]:
+        """
+        Get the history prompt messages
+        """
+
+    @abstractmethod
+    def get_history_prompt_text(self) -> str:
+        """
+        Get the history prompt text
+        """
diff --git a/api/core/memory/model_context_memory.py b/api/core/memory/model_context_memory.py
new file mode 100644
index 0000000000..3dc58e4baf
--- /dev/null
+++ b/api/core/memory/model_context_memory.py
@@ -0,0 +1,205 @@
+import json
+from collections.abc import Sequence
+from typing import Optional
+
+from core.model_manager import ModelInstance
+from core.model_runtime.entities import (
+    ImagePromptMessageContent,
+    PromptMessageRole,
+    TextPromptMessageContent,
+)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    UserPromptMessage,
+)
+from core.prompt.entities.advanced_prompt_entities import LLMMemoryType
+from core.prompt.utils.extract_thread_messages import extract_thread_messages
+from extensions.ext_database import db
+from models.model import Conversation, Message
+from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus
+
+
+class ModelContextMemory:
+    def __init__(self, conversation: Conversation, node_id: str, model_instance: ModelInstance) -> None:
+        self.conversation = conversation
+        self.node_id = node_id
+        self.model_instance = model_instance
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+    ) -> Sequence[PromptMessage]:
+        """
+        Get history prompt messages.
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        """
+        # fetch limited messages, and return reversed
+        query = (
+            db.session.query(
+                Message.id,
+                Message.query,
+                Message.answer,
+                Message.created_at,
+                Message.workflow_run_id,
+                Message.parent_message_id,
+                Message.answer_tokens,
+            )
+            .filter(
+                Message.conversation_id == self.conversation.id,
+            )
+            .order_by(Message.created_at.desc())
+        )
+
+        if message_limit and message_limit > 0:
+            message_limit = min(message_limit, 500)
+        else:
+            message_limit = 500
+
+        messages = query.limit(message_limit).all()
+
+        # instead of all messages from the conversation, we only need to extract messages
+        # that belong to the thread of last message
+        thread_messages = extract_thread_messages(messages)
+
+        # for newly created message, its answer is temporarily empty, we don't need to add it to memory
+        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
+            thread_messages.pop(0)
+        if len(thread_messages) == 0:
+            return []
+        last_thread_message = list(reversed(thread_messages))[0]
+        last_node_execution = (
+            db.session.query(WorkflowNodeExecution)
+            .filter(
+                WorkflowNodeExecution.workflow_run_id == last_thread_message.workflow_run_id,
+                WorkflowNodeExecution.node_id == self.node_id,
+                WorkflowNodeExecution.status.in_(
+                    [WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
+                ),
+            )
+            .order_by(WorkflowNodeExecution.created_at.desc())
+            .first()
+        )
+        prompt_messages: list[PromptMessage] = []
+
+        # files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
+        # if files:
+        #     file_extra_config = None
+        #     if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+        #         file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+        #     else:
+        #         if message.workflow_run_id:
+        #             workflow_run = (
+        #                 db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
+        #             )
+
+        #             if workflow_run and workflow_run.workflow:
+        #                 file_extra_config = FileUploadConfigManager.convert(
+        #                     workflow_run.workflow.features_dict, is_vision=False
+        #                 )
+
+        #     detail = ImagePromptMessageContent.DETAIL.LOW
+        #     if file_extra_config and app_record:
+        #         file_objs = file_factory.build_from_message_files(
+        #             message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
+        #         )
+        #         if file_extra_config.image_config and file_extra_config.image_config.detail:
+        #             detail = file_extra_config.image_config.detail
+        #     else:
+        #         file_objs = []
+
+        #     if not file_objs:
+        #         prompt_messages.append(UserPromptMessage(content=message.query))
+        #     else:
+        #         prompt_message_contents: list[PromptMessageContentUnionTypes] = []
+        #         prompt_message_contents.append(TextPromptMessageContent(data=message.query))
+        #         for file in file_objs:
+        #             prompt_message = file_manager.to_prompt_message_content(
+        #                 file,
+        #                 image_detail_config=detail,
+        #             )
+        #             prompt_message_contents.append(prompt_message)
+
+        #         prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+
+        # else:
+        #     prompt_messages.append(UserPromptMessage(content=message.query))
+
+        if last_node_execution and last_node_execution.process_data:
+            try:
+                process_data = json.loads(last_node_execution.process_data)
+                if process_data.get("memory_type", "") == LLMMemoryType.INDEPENDENT:
+                    for prompt in process_data.get("prompts", []):
+                        if prompt.get("role") == "user":
+                            prompt_messages.append(
+                                UserPromptMessage(
+                                    content=prompt.get("content"),
+                                )
+                            )
+                        elif prompt.get("role") == "assistant":
+                            prompt_messages.append(
+                                AssistantPromptMessage(
+                                    content=prompt.get("content"),
+                                )
+                            )
+                    output = (
+                        json.loads(last_node_execution.outputs).get("text", "") if last_node_execution.outputs else ""
+                    )
+                    prompt_messages.append(AssistantPromptMessage(content=output))
+            except json.JSONDecodeError:
+                pass
+
+        if not prompt_messages:
+            return []
+
+        # prune the chat message if it exceeds the max token limit
+        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
+
+        if curr_message_tokens > max_token_limit:
+            pruned_memory = []
+            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
+                pruned_memory.append(prompt_messages.pop(0))
+                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
+
+        return prompt_messages
+
+    def get_history_prompt_text(
+        self,
+        human_prefix: str = "Human",
+        ai_prefix: str = "Assistant",
+        max_token_limit: int = 2000,
+        message_limit: Optional[int] = None,
+    ) -> str:
+        """
+        Get history prompt text.
+        :param human_prefix: human prefix
+        :param ai_prefix: ai prefix
+        :param max_token_limit: max token limit
+        :param message_limit: message limit
+        :return:
+        """
+        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
+
+        string_messages = []
+        for m in prompt_messages:
+            if m.role == PromptMessageRole.USER:
+                role = human_prefix
+            elif m.role == PromptMessageRole.ASSISTANT:
+                role = ai_prefix
+            else:
+                continue
+
+            if isinstance(m.content, list):
+                inner_msg = ""
+                for content in m.content:
+                    if isinstance(content, TextPromptMessageContent):
+                        inner_msg += f"{content.data}\n"
+                    elif isinstance(content, ImagePromptMessageContent):
+                        inner_msg += "[image]\n"
+
+                string_messages.append(f"{role}: {inner_msg.strip()}")
+            else:
+                message = f"{role}: {m.content}"
+                string_messages.append(message)
+
+        return "\n".join(string_messages)
diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py
index c8e7b414df..fb11297e92 100644
--- a/api/core/prompt/entities/advanced_prompt_entities.py
+++ b/api/core/prompt/entities/advanced_prompt_entities.py
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Literal, Optional
 
 from pydantic import BaseModel
@@ -24,6 +25,11 @@ class CompletionModelPromptTemplate(BaseModel):
     edition_type: Optional[Literal["basic", "jinja2"]] = None
 
 
+class LLMMemoryType(str, Enum):
+    INDEPENDENT = "independent"
+    GLOBAL = "global"
+
+
 class MemoryConfig(BaseModel):
     """
     Memory Config.
@@ -48,3 +54,4 @@ class MemoryConfig(BaseModel):
     role_prefix: Optional[RolePrefix] = None
     window: WindowConfig
     query_prompt_template: Optional[str] = None
+    type: LLMMemoryType = LLMMemoryType.GLOBAL
diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py
index f7aef76c87..f73b54069c 100644
--- a/api/core/prompt/utils/extract_thread_messages.py
+++ b/api/core/prompt/utils/extract_thread_messages.py
@@ -1,9 +1,8 @@
-from typing import Any
-
 from constants import UUID_NIL
+from models.model import Message
 
 
-def extract_thread_messages(messages: list[Any]):
+def extract_thread_messages(messages: list[Message]) -> list[Message]:
     thread_messages = []
     next_message = None
 
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 1089e7168e..21dedc6ace 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -13,6 +13,7 @@ from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
+from core.memory.model_context_memory import ModelContextMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
@@ -39,7 +40,7 @@ from core.model_runtime.entities.model_entities import (
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
-from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, LLMMemoryType, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.variables import (
     ArrayAnySegment,
@@ -190,6 +191,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 ),
                 "model_provider": model_config.provider,
                 "model_name": model_config.model,
+                "memory_type": self.node_data.memory.type if self.node_data.memory else None,
             }
 
             # handle invoke result
@@ -553,10 +555,9 @@ class LLMNode(BaseNode[LLMNodeData]):
 
     def _fetch_memory(
         self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
-    ) -> Optional[TokenBufferMemory]:
+    ) -> Optional[TokenBufferMemory | ModelContextMemory]:
         if not node_data_memory:
             return None
-
         # get conversation id
         conversation_id_variable = self.graph_runtime_state.variable_pool.get(
             ["sys", SystemVariableKey.CONVERSATION_ID.value]
@@ -575,7 +576,15 @@ class LLMNode(BaseNode[LLMNodeData]):
         if not conversation:
             return None
 
-        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+        memory = (
+            TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+            if node_data_memory.type == LLMMemoryType.GLOBAL
+            else ModelContextMemory(
+                conversation=conversation,
+                node_id=self.node_id,
+                model_instance=model_instance,
+            )
+        )
 
         return memory
 
@@ -585,7 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         sys_query: str | None = None,
         sys_files: Sequence["File"],
         context: str | None = None,
-        memory: TokenBufferMemory | None = None,
+        memory: TokenBufferMemory | ModelContextMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         memory_config: MemoryConfig | None = None,
@@ -1201,7 +1210,7 @@ def _calculate_rest_token(
 
 def _handle_memory_chat_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: TokenBufferMemory | ModelContextMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> Sequence[PromptMessage]:
@@ -1218,7 +1227,7 @@ def _handle_memory_chat_mode(
 
 def _handle_memory_completion_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: TokenBufferMemory | ModelContextMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> str:
diff --git a/web/app/components/workflow/nodes/_base/components/memory-config.tsx b/web/app/components/workflow/nodes/_base/components/memory-config.tsx
index 446fcfa8ae..23edcb66ba 100644
--- a/web/app/components/workflow/nodes/_base/components/memory-config.tsx
+++ b/web/app/components/workflow/nodes/_base/components/memory-config.tsx
@@ -4,12 +4,13 @@ import React, { useCallback } from 'react'
 import { useTranslation } from 'react-i18next'
 import produce from 'immer'
 import type { Memory } from '../../../types'
-import { MemoryRole } from '../../../types'
+import { LLMMemoryType, MemoryRole } from '../../../types'
 import cn from '@/utils/classnames'
 import Field from '@/app/components/workflow/nodes/_base/components/field'
 import Switch from '@/app/components/base/switch'
 import Slider from '@/app/components/base/slider'
 import Input from '@/app/components/base/input'
+import { SimpleSelect } from '@/app/components/base/select'
 
 const i18nPrefix = 'workflow.nodes.common.memory'
 const WINDOW_SIZE_MIN = 1
@@ -54,6 +55,7 @@ type Props = {
 const MEMORY_DEFAULT: Memory = {
   window: { enabled: false, size: WINDOW_SIZE_DEFAULT },
   query_prompt_template: '{{#sys.query#}}',
+  type: LLMMemoryType.GLOBAL,
 }
 
 const MemoryConfig: FC<Props> = ({
@@ -178,6 +180,23 @@
           />
         </div>
       </Field>
+      <div className='mt-4'>
+        <SimpleSelect
+          className='w-full'
+          notClearable
+          items={[
+            { value: LLMMemoryType.GLOBAL, name: 'Global' },
+            { value: LLMMemoryType.INDEPENDENT, name: 'Independent' },
+          ]}
+          defaultValue={payload?.type || LLMMemoryType.GLOBAL}
+          onSelect={(value) => {
+            const newPayload = produce(payload || MEMORY_DEFAULT, (draft) => {
+              draft.type = value.value as LLMMemoryType
+            })
+            onChange(newPayload)
+          }}
+        />
+      </div>
       {canSetRoleName && (
         <div className='mt-4'>
           <div className='leading-[18px] text-[13px] font-semibold text-text-secondary'>{t(`${i18nPrefix}.conversationRoleName`)}</div>
diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts
index 884bdfbd10..6de9ac8817 100644
--- a/web/app/components/workflow/types.ts
+++ b/web/app/components/workflow/types.ts
@@ -234,6 +234,11 @@ export type RolePrefix = {
   assistant: string
 }
 
+export enum LLMMemoryType {
+  INDEPENDENT = 'independent',
+  GLOBAL = 'global',
+}
+
 export type Memory = {
   role_prefix?: RolePrefix
   window: {
@@ -241,6 +246,7 @@
     size: number | string | null
   }
   query_prompt_template: string
+  type: LLMMemoryType
 }
 
 export enum VarType {
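
Reviewer sketch (not taken from the patch itself): the core idea of ModelContextMemory above is that an LLM node's history is rebuilt from the prompts and output recorded in that node's last WorkflowNodeExecution, instead of from the whole conversation. The snippet below only restates that replay rule; the node_execution dict and its contents are made-up stand-ins for the ORM row (process_data and outputs are stored as JSON strings), so it runs on its own without a database.

    import json

    # Stand-in for the two WorkflowNodeExecution columns the new memory class reads.
    # The values are illustrative only.
    node_execution = {
        "process_data": json.dumps({
            "memory_type": "independent",
            "prompts": [
                {"role": "user", "content": "Summarize the ticket."},
                {"role": "assistant", "content": "It reports a login timeout."},
                {"role": "user", "content": "Draft a short reply."},
            ],
        }),
        "outputs": json.dumps({"text": "Hi, thanks for the report. We are looking into the timeout."}),
    }

    def replay_node_history(execution: dict) -> list[tuple[str, str]]:
        """Rebuild (role, content) pairs the way the patch does: replay the prompts
        recorded for the node's last run, then append that run's own answer."""
        history: list[tuple[str, str]] = []
        process_data = json.loads(execution["process_data"])
        if process_data.get("memory_type", "") == "independent":
            for prompt in process_data.get("prompts", []):
                if prompt.get("role") in ("user", "assistant"):
                    history.append((prompt["role"], prompt.get("content") or ""))
            output = json.loads(execution["outputs"]).get("text", "") if execution.get("outputs") else ""
            history.append(("assistant", output))
        return history

    for role, content in replay_node_history(node_execution):
        print(f"{role}: {content}")

The point of the "independent" type is that this history comes only from the same node's previous execution, while "global" keeps the existing TokenBufferMemory behaviour of sharing the whole conversation history across nodes.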