chore: lack of the user message
parent
dd02a9ac9d
commit
48be8fb6cc
@ -0,0 +1,18 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessage
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMemory(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_history_prompt_messages(self) -> Sequence[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Get the history prompt messages
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_history_prompt_text(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the history prompt text
|
||||||
|
"""
|
||||||
@ -0,0 +1,205 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
from core.model_runtime.entities import (
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessageRole,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.prompt.entities.advanced_prompt_entities import LLMMemoryType
|
||||||
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import Conversation, Message
|
||||||
|
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContextMemory:
|
||||||
|
def __init__(self, conversation: Conversation, node_id: str, model_instance: ModelInstance) -> None:
|
||||||
|
self.conversation = conversation
|
||||||
|
self.node_id = node_id
|
||||||
|
self.model_instance = model_instance
|
||||||
|
|
||||||
|
def get_history_prompt_messages(
|
||||||
|
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Get history prompt messages.
|
||||||
|
:param max_token_limit: max token limit
|
||||||
|
:param message_limit: message limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# fetch limited messages, and return reversed
|
||||||
|
query = (
|
||||||
|
db.session.query(
|
||||||
|
Message.id,
|
||||||
|
Message.query,
|
||||||
|
Message.answer,
|
||||||
|
Message.created_at,
|
||||||
|
Message.workflow_run_id,
|
||||||
|
Message.parent_message_id,
|
||||||
|
Message.answer_tokens,
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
Message.conversation_id == self.conversation.id,
|
||||||
|
)
|
||||||
|
.order_by(Message.created_at.desc())
|
||||||
|
)
|
||||||
|
|
||||||
|
if message_limit and message_limit > 0:
|
||||||
|
message_limit = min(message_limit, 500)
|
||||||
|
else:
|
||||||
|
message_limit = 500
|
||||||
|
|
||||||
|
messages = query.limit(message_limit).all()
|
||||||
|
|
||||||
|
# instead of all messages from the conversation, we only need to extract messages
|
||||||
|
# that belong to the thread of last message
|
||||||
|
thread_messages = extract_thread_messages(messages)
|
||||||
|
|
||||||
|
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||||
|
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||||
|
thread_messages.pop(0)
|
||||||
|
if len(thread_messages) == 0:
|
||||||
|
return []
|
||||||
|
last_thread_message = list(reversed(thread_messages))[0]
|
||||||
|
last_node_execution = (
|
||||||
|
db.session.query(WorkflowNodeExecution)
|
||||||
|
.filter(
|
||||||
|
WorkflowNodeExecution.workflow_run_id == last_thread_message.workflow_run_id,
|
||||||
|
WorkflowNodeExecution.node_id == self.node_id,
|
||||||
|
WorkflowNodeExecution.status.in_(
|
||||||
|
[WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.order_by(WorkflowNodeExecution.created_at.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
|
||||||
|
# files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||||
|
# if files:
|
||||||
|
# file_extra_config = None
|
||||||
|
# if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
|
# file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||||
|
# else:
|
||||||
|
# if message.workflow_run_id:
|
||||||
|
# workflow_run = (
|
||||||
|
# db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if workflow_run and workflow_run.workflow:
|
||||||
|
# file_extra_config = FileUploadConfigManager.convert(
|
||||||
|
# workflow_run.workflow.features_dict, is_vision=False
|
||||||
|
# )
|
||||||
|
|
||||||
|
# detail = ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
# if file_extra_config and app_record:
|
||||||
|
# file_objs = file_factory.build_from_message_files(
|
||||||
|
# message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||||
|
# )
|
||||||
|
# if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||||
|
# detail = file_extra_config.image_config.detail
|
||||||
|
# else:
|
||||||
|
# file_objs = []
|
||||||
|
|
||||||
|
# if not file_objs:
|
||||||
|
# prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
|
# else:
|
||||||
|
# prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
|
# prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
|
# for file in file_objs:
|
||||||
|
# prompt_message = file_manager.to_prompt_message_content(
|
||||||
|
# file,
|
||||||
|
# image_detail_config=detail,
|
||||||
|
# )
|
||||||
|
# prompt_message_contents.append(prompt_message)
|
||||||
|
|
||||||
|
# prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
|
|
||||||
|
# else:
|
||||||
|
# prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
|
if last_node_execution and last_node_execution.process_data:
|
||||||
|
try:
|
||||||
|
process_data = json.loads(last_node_execution.process_data)
|
||||||
|
if process_data.get("memory_type", "") == LLMMemoryType.INDEPENDENT:
|
||||||
|
for prompt in process_data.get("prompts", []):
|
||||||
|
if prompt.get("role") == "user":
|
||||||
|
prompt_messages.append(
|
||||||
|
UserPromptMessage(
|
||||||
|
content=prompt.get("content"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif prompt.get("role") == "assistant":
|
||||||
|
prompt_messages.append(
|
||||||
|
AssistantPromptMessage(
|
||||||
|
content=prompt.get("content"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output = (
|
||||||
|
json.loads(last_node_execution.outputs).get("text", "") if last_node_execution.outputs else ""
|
||||||
|
)
|
||||||
|
prompt_messages.append(AssistantPromptMessage(content=output))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not prompt_messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# prune the chat message if it exceeds the max token limit
|
||||||
|
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||||
|
|
||||||
|
if curr_message_tokens > max_token_limit:
|
||||||
|
pruned_memory = []
|
||||||
|
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||||
|
pruned_memory.append(prompt_messages.pop(0))
|
||||||
|
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||||
|
|
||||||
|
return prompt_messages
|
||||||
|
|
||||||
|
def get_history_prompt_text(
|
||||||
|
self,
|
||||||
|
human_prefix: str = "Human",
|
||||||
|
ai_prefix: str = "Assistant",
|
||||||
|
max_token_limit: int = 2000,
|
||||||
|
message_limit: Optional[int] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get history prompt text.
|
||||||
|
:param human_prefix: human prefix
|
||||||
|
:param ai_prefix: ai prefix
|
||||||
|
:param max_token_limit: max token limit
|
||||||
|
:param message_limit: message limit
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
|
||||||
|
|
||||||
|
string_messages = []
|
||||||
|
for m in prompt_messages:
|
||||||
|
if m.role == PromptMessageRole.USER:
|
||||||
|
role = human_prefix
|
||||||
|
elif m.role == PromptMessageRole.ASSISTANT:
|
||||||
|
role = ai_prefix
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(m.content, list):
|
||||||
|
inner_msg = ""
|
||||||
|
for content in m.content:
|
||||||
|
if isinstance(content, TextPromptMessageContent):
|
||||||
|
inner_msg += f"{content.data}\n"
|
||||||
|
elif isinstance(content, ImagePromptMessageContent):
|
||||||
|
inner_msg += "[image]\n"
|
||||||
|
|
||||||
|
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||||
|
else:
|
||||||
|
message = f"{role}: {m.content}"
|
||||||
|
string_messages.append(message)
|
||||||
|
|
||||||
|
return "\n".join(string_messages)
|
||||||
Loading…
Reference in New Issue