diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 2254b3d4d5..5249e1520c 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -16,7 +16,7 @@ from core.model_runtime.entities.message_entities import PromptMessageContentUni
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from factories import file_factory
-from models.model import AppMode, Conversation, Message, MessageFile
+from models.model import AppMode, Conversation, Message, MessageFile, NodeFileUsage
 from models.workflow import WorkflowRun
 
 
@@ -26,7 +26,11 @@ class TokenBufferMemory:
         self.model_instance = model_instance
 
     def get_history_prompt_messages(
-        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+        self,
+        max_token_limit: int = 2000,
+        message_limit: Optional[int] = None,
+        *,
+        allowed_node_id: str | None = None,
     ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
@@ -70,8 +74,47 @@ class TokenBufferMemory:
         messages = list(reversed(thread_messages))
 
         prompt_messages: list[PromptMessage] = []
+
+        # Attachment filtering strategy
+        #   • allowed_node_id is None       → keep files used by ANY node (shared memory)
+        #   • allowed_node_id == ""         → strip ALL attachments
+        #   • allowed_node_id == a node id  → keep only files that node used
+        # NOTE(review): None is also the default, so every existing caller now has its
+        # attachments filtered against NodeFileUsage — confirm that is intended.
+
+        allowed_upload_ids: set[str] | None = None
+
+        if allowed_node_id == "":
+            allowed_upload_ids = set()  # keep none → strip all attachments
+        elif allowed_node_id is None:
+            # Shared memory: allow files used by any node
+            usage_rows = (
+                db.session.query(NodeFileUsage.upload_file_id)
+                .filter(NodeFileUsage.conversation_id == self.conversation.id)
+                .all()
+            )
+            allowed_upload_ids = {str(r[0]) for r in usage_rows}
+        else:
+            # Node-specific filtering
+            usage_rows = (
+                db.session.query(NodeFileUsage.upload_file_id)
+                .filter(
+                    NodeFileUsage.conversation_id == self.conversation.id,
+                    NodeFileUsage.node_id == allowed_node_id,
+                )
+                .all()
+            )
+            allowed_upload_ids = {str(r[0]) for r in usage_rows}
+
         for message in messages:
             files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
+
+            # If attachment filtering is enabled, filter MessageFile list first.
+            if allowed_upload_ids is not None:
+                files = [
+                    f for f in files if f.upload_file_id and str(f.upload_file_id) in allowed_upload_ids
+                ]
+
             if files:
                 file_extra_config = None
                 if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index ee1b1e1dd2..82130326ba 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -34,7 +34,7 @@ from core.workflow.nodes.tool.tool_node import ToolNode
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from factories.agent_factory import get_plugin_agent_strategy
-from models.model import Conversation
+from models.model import Conversation, NodeFileUsage
 
 
 class AgentNode(ToolNode):
@@ -257,35 +257,6 @@ class AgentNode(ToolNode):
                 value = cast(dict[str, Any], value)
                 model_instance, model_schema = self._fetch_model(value)
                 history_prompt_messages: list[dict[str, Any]] = []
-                if node_data.memory:
-                    memory = self._fetch_memory(model_instance)
-                    if memory:
-                        prompt_messages = memory.get_history_prompt_messages(
-                            message_limit=node_data.memory.window.size if node_data.memory.window.size else None
-                        )
-                        history_prompt_messages = [
-                            prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
-                        ]
-
-                # Strip file attachments from memory to prevent cross-turn leakage
-                def _strip_files_from_history(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-                    stripped: list[dict[str, Any]] = []
-                    for m in messages:
-                        if m.get("role") == "user" and isinstance(m.get("content"), list):
-                            # Keep only text content, remove file attachments
-                            contents = [
-                                c
-                                for c in m["content"]
-                                if not (isinstance(c, dict) and c.get("type") and c.get("type") != "text")
-                            ]
-                            if not contents:
-                                contents = []
-                            m = {**m, "content": contents}
-                        stripped.append(m)
-                    return stripped
-
-                history_prompt_messages = _strip_files_from_history(history_prompt_messages)
-
                 # Check if this agent node references sys.files
                 def _input_uses_sys_files(_agent_input):
                     if _agent_input.type == "variable":
@@ -298,9 +269,19 @@ class AgentNode(ToolNode):
                     _input_uses_sys_files(inp) for inp in node_data.agent_parameters.values()
                 )
 
-                if uses_sys_files_for_node:
-                    # History already stripped once above; ensure it's stripped (idempotent call)
-                    history_prompt_messages = _strip_files_from_history(history_prompt_messages)
+                conv_var = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
+                conversation_id_val = conv_var.value if isinstance(conv_var, StringSegment) else None
+
+                if node_data.memory:
+                    memory = self._fetch_memory(model_instance)
+                    if memory:
+                        prompt_messages = memory.get_history_prompt_messages(
+                            message_limit=node_data.memory.window.size if node_data.memory.window.size else None,
+                            allowed_node_id=None if uses_sys_files_for_node else "",
+                        )
+                        history_prompt_messages = [
+                            prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
+                        ]
 
                 if model_schema:
                     # remove structured output feature to support old version agent plugin
@@ -340,9 +321,44 @@ class AgentNode(ToolNode):
                                 continue
 
                         # Append synthetic user message with current turn attachments
-                        synthetic_user_prompt = UserPromptMessage(content=prompt_contents)
+                        synthetic_user_prompt = UserPromptMessage(
+                            content=prompt_contents,
+                            name=f"__SYS_FILES__|{self.node_id}",  # mark with node id for future filtering
+                        )
                         history_prompt_messages.append(synthetic_user_prompt.model_dump(mode="json"))
 
+                        # Persist usage to durable table
+                        try:
+                            if conversation_id_val:
+                                with Session(db.engine) as session:
+                                    for f in files:
+                                        upload_id = getattr(f, "related_id", None)
+                                        if not upload_id:
+                                            continue
+
+                                        # Upsert-ish: skip if record already exists
+                                        exists_stmt = select(NodeFileUsage.id).where(
+                                            NodeFileUsage.conversation_id == conversation_id_val,
+                                            NodeFileUsage.node_id == self.node_id,
+                                            NodeFileUsage.upload_file_id == upload_id,
+                                        )
+                                        if not session.scalar(exists_stmt):
+                                            session.add(
+                                                NodeFileUsage(
+                                                    conversation_id=conversation_id_val,
+                                                    node_id=self.node_id,
+                                                    upload_file_id=upload_id,
+                                                    message_id=None,
+                                                )
+                                            )
+                                    session.commit()
+                        except Exception:
+                            # Best-effort: never fail the agent run over usage tracking, but log it,
+                            # because a missing row later strips these files from shared history.
+                            import logging
+
+                            logging.getLogger(__name__).exception("Failed to persist NodeFileUsage rows")
+
                 value["history_prompt_messages"] = history_prompt_messages
                 result[parameter_name] = value
 
diff --git a/api/migrations/versions/2025_07_07_1152-b2a0bfccd123_add_node_file_usage.py b/api/migrations/versions/2025_07_07_1152-b2a0bfccd123_add_node_file_usage.py
new file mode 100644
index 0000000000..f47e8fd10c
--- /dev/null
+++ b/api/migrations/versions/2025_07_07_1152-b2a0bfccd123_add_node_file_usage.py
@@ -0,0 +1,39 @@
+"""add node_file_usage table
+
+Revision ID: b2a0bfccd123
+Revises: 0ab65e1cc7fa
+Create Date: 2025-07-07 11:52:00
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'b2a0bfccd123'
+down_revision = '0ab65e1cc7fa'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "node_file_usage",
+        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("uuid_generate_v4()")),
+        sa.Column("conversation_id", postgresql.UUID(as_uuid=True), nullable=False),
+        sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=True),
+        sa.Column("node_id", sa.String(length=64), nullable=False),
+        sa.Column("upload_file_id", postgresql.UUID(as_uuid=True), nullable=False),
+        sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
+        sa.ForeignKeyConstraint(["conversation_id"], ["conversations.id"],),
+        sa.ForeignKeyConstraint(["message_id"], ["messages.id"],),
+        sa.ForeignKeyConstraint(["upload_file_id"], ["upload_files.id"],),
+    )
+    # Index names match NodeFileUsage.__table_args__ so the model and the schema do not drift.
+    op.create_index("node_file_usage_conversation_id_idx", "node_file_usage", ["conversation_id"])
+    op.create_index("node_file_usage_node_id_idx", "node_file_usage", ["node_id"])
+
+
+def downgrade() -> None:
+    op.drop_table("node_file_usage")
diff --git a/api/models/__init__.py b/api/models/__init__.py
index 83b50eb099..cc9faeeefb 100644
--- a/api/models/__init__.py
+++ b/api/models/__init__.py
@@ -48,6 +48,7 @@ from .model import (
     MessageChain,
     MessageFeedback,
     MessageFile,
+    NodeFileUsage,
     OperationLog,
     RecommendedApp,
     Site,
@@ -138,6 +139,7 @@ __all__ = [
     "MessageChain",
     "MessageFeedback",
     "MessageFile",
+    "NodeFileUsage",
     "OperationLog",
     "PinnedConversation",
     "Provider",
diff --git a/api/models/model.py b/api/models/model.py
index 93737043d5..9dbe6e8726 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -1747,6 +1747,27 @@ class MessageAgentThought(Base):
             return {}
 
 
+class NodeFileUsage(Base):
+    __tablename__ = "node_file_usage"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="node_file_usage_pkey"),
+        db.Index("node_file_usage_conversation_id_idx", "conversation_id"),
+        db.Index("node_file_usage_node_id_idx", "node_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
+    message_id = db.Column(StringUUID, db.ForeignKey("messages.id"), nullable=True)
+    node_id = db.Column(db.String(64), nullable=False)
+    upload_file_id = db.Column(StringUUID, db.ForeignKey("upload_files.id"), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+    def __repr__(self) -> str:  # pragma: no cover
+        return "NodeFileUsage(conversation_id={cid}, node_id={nid}, upload_file_id={fid})".format(
+            cid=self.conversation_id, nid=self.node_id, fid=self.upload_file_id
+        )
+
+
 class DatasetRetrieverResource(Base):
     __tablename__ = "dataset_retriever_resources"
     __table_args__ = (