chore: handle missing user message

feat/model-memory
Novice 10 months ago
parent dd02a9ac9d
commit 48be8fb6cc

@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from core.model_runtime.entities.message_entities import PromptMessage
class BaseMemory(ABC):
@abstractmethod
def get_history_prompt_messages(self) -> Sequence[PromptMessage]:
"""
Get the history prompt messages
"""
@abstractmethod
def get_history_prompt_text(self) -> str:
"""
Get the history prompt text
"""

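As a reading aid, here is a minimal sketch of a concrete implementation of this interface. The module path of the new base class is not shown in this hunk, so the import below is an assumption, and InMemoryHistory is purely illustrative:

from collections.abc import Sequence

from core.memory.base_memory import BaseMemory  # assumed module path for the new base class above
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage


class InMemoryHistory(BaseMemory):
    """Illustrative subclass that keeps a fixed list of prompt messages in memory."""

    def __init__(self, messages: Sequence[PromptMessage]) -> None:
        self._messages = list(messages)

    def get_history_prompt_messages(self) -> Sequence[PromptMessage]:
        return list(self._messages)

    def get_history_prompt_text(self) -> str:
        # Render each message as "<role>: <content>", mirroring the text form used elsewhere in this commit.
        return "\n".join(
            f"{'Human' if isinstance(m, UserPromptMessage) else 'Assistant'}: {m.content}"
            for m in self._messages
        )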
@@ -0,0 +1,205 @@
import json
from collections.abc import Sequence
from typing import Optional
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import LLMMemoryType
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Conversation, Message
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus
class ModelContextMemory:
def __init__(self, conversation: Conversation, node_id: str, model_instance: ModelInstance) -> None:
self.conversation = conversation
self.node_id = node_id
self.model_instance = model_instance
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
:param message_limit: message limit
"""
        # fetch a limited number of messages, newest first
query = (
db.session.query(
Message.id,
Message.query,
Message.answer,
Message.created_at,
Message.workflow_run_id,
Message.parent_message_id,
Message.answer_tokens,
)
.filter(
Message.conversation_id == self.conversation.id,
)
.order_by(Message.created_at.desc())
)
if message_limit and message_limit > 0:
message_limit = min(message_limit, 500)
else:
message_limit = 500
messages = query.limit(message_limit).all()
        # instead of all messages from the conversation, we only need the messages
        # that belong to the thread of the last message
thread_messages = extract_thread_messages(messages)
        # a newly created message has a temporarily empty answer, so we don't add it to memory
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
if len(thread_messages) == 0:
return []
        # thread_messages is ordered newest-first, so after popping the in-flight message above,
        # the most recent completed message is at index 0
        last_thread_message = thread_messages[0]
last_node_execution = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.workflow_run_id == last_thread_message.workflow_run_id,
WorkflowNodeExecution.node_id == self.node_id,
WorkflowNodeExecution.status.in_(
[WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
),
)
.order_by(WorkflowNodeExecution.created_at.desc())
.first()
)
prompt_messages: list[PromptMessage] = []
# files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
# if files:
# file_extra_config = None
# if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
# file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
# else:
# if message.workflow_run_id:
# workflow_run = (
# db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
# )
# if workflow_run and workflow_run.workflow:
# file_extra_config = FileUploadConfigManager.convert(
# workflow_run.workflow.features_dict, is_vision=False
# )
# detail = ImagePromptMessageContent.DETAIL.LOW
# if file_extra_config and app_record:
# file_objs = file_factory.build_from_message_files(
# message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
# )
# if file_extra_config.image_config and file_extra_config.image_config.detail:
# detail = file_extra_config.image_config.detail
# else:
# file_objs = []
# if not file_objs:
# prompt_messages.append(UserPromptMessage(content=message.query))
# else:
# prompt_message_contents: list[PromptMessageContentUnionTypes] = []
# prompt_message_contents.append(TextPromptMessageContent(data=message.query))
# for file in file_objs:
# prompt_message = file_manager.to_prompt_message_content(
# file,
# image_detail_config=detail,
# )
# prompt_message_contents.append(prompt_message)
# prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
# else:
# prompt_messages.append(UserPromptMessage(content=message.query))
if last_node_execution and last_node_execution.process_data:
try:
process_data = json.loads(last_node_execution.process_data)
if process_data.get("memory_type", "") == LLMMemoryType.INDEPENDENT:
for prompt in process_data.get("prompts", []):
if prompt.get("role") == "user":
prompt_messages.append(
UserPromptMessage(
content=prompt.get("content"),
)
)
elif prompt.get("role") == "assistant":
prompt_messages.append(
AssistantPromptMessage(
content=prompt.get("content"),
)
)
output = (
json.loads(last_node_execution.outputs).get("text", "") if last_node_execution.outputs else ""
)
prompt_messages.append(AssistantPromptMessage(content=output))
except json.JSONDecodeError:
pass
if not prompt_messages:
return []
        # prune the chat messages if they exceed the max token limit
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
if curr_message_tokens > max_token_limit:
pruned_memory = []
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: Optional[int] = None,
) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix
:param ai_prefix: ai prefix
:param max_token_limit: max token limit
:param message_limit: message limit
:return:
"""
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

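For reference, the independent-memory branch above only reads a handful of keys from the stored node execution. A sketch of the expected shapes, with illustrative values (only memory_type, prompts, role, content, and the "text" output key are taken from the code above):

# Shape of WorkflowNodeExecution.process_data as consumed by get_history_prompt_messages.
# Values are made up; keys match the lookups in the code above.
process_data = {
    "memory_type": "independent",  # LLMMemoryType.INDEPENDENT; written by LLMNode per the later hunk
    "prompts": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
        {"role": "user", "content": "And of Germany?"},
    ],
}

# Shape of WorkflowNodeExecution.outputs; only the "text" key is read.
outputs = {"text": "Berlin."}

# The method above reconstructs the history as:
#   UserPromptMessage("What is the capital of France?")
#   AssistantPromptMessage("Paris.")
#   UserPromptMessage("And of Germany?")
#   AssistantPromptMessage("Berlin.")  # appended from outputs["text"]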
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel
@@ -24,6 +25,11 @@ class CompletionModelPromptTemplate(BaseModel):
edition_type: Optional[Literal["basic", "jinja2"]] = None
class LLMMemoryType(str, Enum):
INDEPENDENT = "independent"
GLOBAL = "global"
class MemoryConfig(BaseModel):
"""
Memory Config.
@@ -48,3 +54,4 @@ class MemoryConfig(BaseModel):
role_prefix: Optional[RolePrefix] = None
window: WindowConfig
query_prompt_template: Optional[str] = None
type: LLMMemoryType = LLMMemoryType.GLOBAL

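A hedged usage sketch of the extended MemoryConfig. WindowConfig's fields (enabled, size) are inferred from the frontend MEMORY_DEFAULT further down and may differ, and the nested-class access is an assumption:

from core.prompt.entities.advanced_prompt_entities import LLMMemoryType, MemoryConfig

# Assumed: WindowConfig is the nested model behind the `window` field and exposes enabled/size.
config = MemoryConfig(
    window=MemoryConfig.WindowConfig(enabled=True, size=50),
    query_prompt_template="{{#sys.query#}}",
    type=LLMMemoryType.INDEPENDENT,  # defaults to LLMMemoryType.GLOBAL when omitted
)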
@@ -1,9 +1,8 @@
from typing import Any
from constants import UUID_NIL
from models.model import Message
def extract_thread_messages(messages: list[Any]):
def extract_thread_messages(messages: list[Message]) -> list[Message]:
thread_messages = []
next_message = None

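The hunk above only tightens the type annotation. For context, here is a rough sketch of the threading idea the callers rely on (walk parent_message_id from the newest message); this is an approximation for illustration, not the actual implementation, which also handles UUID_NIL for legacy messages:

from models.model import Message


def extract_thread_messages_sketch(messages: list[Message]) -> list[Message]:
    # Assumes messages are ordered newest-first, as in the callers above.
    by_id = {m.id: m for m in messages}
    thread: list[Message] = []
    current = messages[0] if messages else None
    while current is not None:
        thread.append(current)
        parent_id = getattr(current, "parent_message_id", None)
        current = by_id.get(parent_id) if parent_id else None
    return thread  # still newest-first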
@@ -13,6 +13,7 @@ from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.model_context_memory import ModelContextMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@@ -39,7 +40,7 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, LLMMemoryType, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import (
ArrayAnySegment,
@@ -190,6 +191,7 @@ class LLMNode(BaseNode[LLMNodeData]):
),
"model_provider": model_config.provider,
"model_name": model_config.model,
"memory_type": self.node_data.memory.type if self.node_data.memory else None,
}
# handle invoke result
@@ -553,10 +555,9 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_memory(
self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
) -> Optional[TokenBufferMemory | ModelContextMemory]:
if not node_data_memory:
return None
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
@@ -575,7 +576,15 @@ class LLMNode(BaseNode[LLMNodeData]):
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
memory = (
TokenBufferMemory(conversation=conversation, model_instance=model_instance)
if node_data_memory.type == LLMMemoryType.GLOBAL
else ModelContextMemory(
conversation=conversation,
node_id=self.node_id,
model_instance=model_instance,
)
)
return memory
@@ -585,7 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
sys_query: str | None = None,
sys_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
memory: TokenBufferMemory | ModelContextMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
memory_config: MemoryConfig | None = None,
@@ -1201,7 +1210,7 @@ def _calculate_rest_token(
def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory: TokenBufferMemory | ModelContextMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1218,7 +1227,7 @@ def _handle_memory_chat_mode(
def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory: TokenBufferMemory | ModelContextMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:

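Both memory classes expose the same get_history_prompt_messages / get_history_prompt_text surface, which is what makes the widened TokenBufferMemory | ModelContextMemory unions above cheap to consume. A minimal sketch of a caller written against the union (the helper name is illustrative):

from typing import Optional

from core.memory.model_context_memory import ModelContextMemory
from core.memory.token_buffer_memory import TokenBufferMemory


def render_history(
    memory: Optional[TokenBufferMemory | ModelContextMemory],
    max_token_limit: int = 2000,
) -> str:
    # Illustrative helper: both classes accept the same keyword arguments here,
    # so callers like _handle_memory_chat_mode need no isinstance checks.
    if memory is None:
        return ""
    return memory.get_history_prompt_text(max_token_limit=max_token_limit)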
@@ -4,12 +4,13 @@ import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import produce from 'immer'
import type { Memory } from '../../../types'
import { MemoryRole } from '../../../types'
import { LLMMemoryType, MemoryRole } from '../../../types'
import cn from '@/utils/classnames'
import Field from '@/app/components/workflow/nodes/_base/components/field'
import Switch from '@/app/components/base/switch'
import Slider from '@/app/components/base/slider'
import Input from '@/app/components/base/input'
import { SimpleSelect } from '@/app/components/base/select'
const i18nPrefix = 'workflow.nodes.common.memory'
const WINDOW_SIZE_MIN = 1
@@ -54,6 +55,7 @@ type Props = {
const MEMORY_DEFAULT: Memory = {
window: { enabled: false, size: WINDOW_SIZE_DEFAULT },
query_prompt_template: '{{#sys.query#}}',
type: LLMMemoryType.GLOBAL,
}
const MemoryConfig: FC<Props> = ({
@@ -178,6 +180,23 @@ const MemoryConfig: FC<Props> = ({
/>
</div>
</div>
<div>
<SimpleSelect
items={[{
value: LLMMemoryType.INDEPENDENT,
name: 'Individual memory',
}, {
value: LLMMemoryType.GLOBAL,
name: 'Global memory',
}]}
onSelect={(value) => {
const newPayload = produce(payload || MEMORY_DEFAULT, (draft) => {
draft.type = value.value as LLMMemoryType
})
onChange(newPayload)
}}
/>
</div>
{canSetRoleName && (
<div className='mt-4'>
<div className='text-xs font-medium uppercase leading-6 text-text-tertiary'>{t(`${i18nPrefix}.conversationRoleName`)}</div>

@@ -234,6 +234,11 @@ export type RolePrefix = {
assistant: string
}
export enum LLMMemoryType {
INDEPENDENT = 'independent',
GLOBAL = 'global',
}
export type Memory = {
role_prefix?: RolePrefix
window: {
@@ -241,6 +246,7 @@ export type Memory = {
size: number | string | null
}
query_prompt_template: string
type: LLMMemoryType
}
export enum VarType {
