From b448a2fccf6f350ffdd5cf7b3f94995736a0562c Mon Sep 17 00:00:00 2001 From: Kalo Chin Date: Mon, 7 Jul 2025 23:23:20 +0900 Subject: [PATCH] feat: Add node file usage tracking and filtering Introduces the NodeFileUsage model and migration to track file usage per node in conversations. Updates token buffer memory and agent node logic to filter file attachments based on node context, preventing cross-node file leakage. Also registers the new model in the models package. --- api/core/memory/token_buffer_memory.py | 45 +++++++++- api/core/workflow/nodes/agent/agent_node.py | 82 +++++++++++-------- ...7_1152-b2a0bfccd123_add_node_file_usage.py | 36 ++++++++ api/models/__init__.py | 2 + api/models/model.py | 21 +++++ 5 files changed, 149 insertions(+), 37 deletions(-) create mode 100644 api/migrations/versions/2025_07_07_1152-b2a0bfccd123_add_node_file_usage.py diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 2254b3d4d5..5249e1520c 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -16,7 +16,7 @@ from core.model_runtime.entities.message_entities import PromptMessageContentUni from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from factories import file_factory -from models.model import AppMode, Conversation, Message, MessageFile +from models.model import AppMode, Conversation, Message, MessageFile, NodeFileUsage from models.workflow import WorkflowRun @@ -26,7 +26,11 @@ class TokenBufferMemory: self.model_instance = model_instance def get_history_prompt_messages( - self, max_token_limit: int = 2000, message_limit: Optional[int] = None + self, + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + *, + allowed_node_id: str | None = None, ) -> Sequence[PromptMessage]: """ Get history prompt messages. @@ -70,8 +74,45 @@ class TokenBufferMemory: messages = list(reversed(thread_messages)) prompt_messages: list[PromptMessage] = [] + + # Attachment filtering strategy + # • allowed_node_id is None → keep files used by ANY node (shared memory) + # • allowed_node_id == "" → strip ALL attachments + # • allowed_node_id == → keep only files this node used + + allowed_upload_ids: set[str] | None = None + + if allowed_node_id == "": + allowed_upload_ids = set() # keep none → strip all attachments + elif allowed_node_id is None: + # Shared memory: allow files used by any node + usage_rows = ( + db.session.query(NodeFileUsage.upload_file_id) + .filter(NodeFileUsage.conversation_id == self.conversation.id) + .all() + ) + allowed_upload_ids = {str(r[0]) for r in usage_rows} + else: + # Node-specific filtering + usage_rows = ( + db.session.query(NodeFileUsage.upload_file_id) + .filter( + NodeFileUsage.conversation_id == self.conversation.id, + NodeFileUsage.node_id == allowed_node_id, + ) + .all() + ) + allowed_upload_ids = {str(r[0]) for r in usage_rows} + for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() + + # If attachment filtering is enabled, filter MessageFile list first. + if allowed_upload_ids is not None: + files = [ + f for f in files if f.upload_file_id and str(f.upload_file_id) in allowed_upload_ids + ] + if files: file_extra_config = None if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index ee1b1e1dd2..82130326ba 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -34,8 +34,7 @@ from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories.agent_factory import get_plugin_agent_strategy -from models.model import Conversation - +from models.model import Conversation, NodeFileUsage class AgentNode(ToolNode): """ @@ -257,35 +256,6 @@ class AgentNode(ToolNode): value = cast(dict[str, Any], value) model_instance, model_schema = self._fetch_model(value) history_prompt_messages: list[dict[str, Any]] = [] - if node_data.memory: - memory = self._fetch_memory(model_instance) - if memory: - prompt_messages = memory.get_history_prompt_messages( - message_limit=node_data.memory.window.size if node_data.memory.window.size else None - ) - history_prompt_messages = [ - prompt_message.model_dump(mode="json") for prompt_message in prompt_messages - ] - - # Strip file attachments from memory to prevent cross-turn leakage - def _strip_files_from_history(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - stripped: list[dict[str, Any]] = [] - for m in messages: - if m.get("role") == "user" and isinstance(m.get("content"), list): - # Keep only text content, remove file attachments - contents = [ - c - for c in m["content"] - if not (isinstance(c, dict) and c.get("type") and c.get("type") != "text") - ] - if not contents: - contents = [] - m = {**m, "content": contents} - stripped.append(m) - return stripped - - history_prompt_messages = _strip_files_from_history(history_prompt_messages) - # Check if this agent node references sys.files def _input_uses_sys_files(_agent_input): if _agent_input.type == "variable": @@ -298,9 +268,19 @@ class AgentNode(ToolNode): _input_uses_sys_files(inp) for inp in node_data.agent_parameters.values() ) - if uses_sys_files_for_node: - # History already stripped once above; ensure it's stripped (idempotent call) - history_prompt_messages = _strip_files_from_history(history_prompt_messages) + conv_var = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value]) + conversation_id_val = conv_var.value if isinstance(conv_var, StringSegment) else None + + if node_data.memory: + memory = self._fetch_memory(model_instance) + if memory: + prompt_messages = memory.get_history_prompt_messages( + message_limit=node_data.memory.window.size if node_data.memory.window.size else None, + allowed_node_id=None if uses_sys_files_for_node else "", + ) + history_prompt_messages = [ + prompt_message.model_dump(mode="json") for prompt_message in prompt_messages + ] if model_schema: # remove structured output feature to support old version agent plugin @@ -340,9 +320,41 @@ class AgentNode(ToolNode): continue # Append synthetic user message with current turn attachments - synthetic_user_prompt = UserPromptMessage(content=prompt_contents) + synthetic_user_prompt = UserPromptMessage( + content=prompt_contents, + name=f"__SYS_FILES__|{self.node_id}", # mark with node id for future filtering + ) history_prompt_messages.append(synthetic_user_prompt.model_dump(mode="json")) + # Persist usage to durable table + try: + if conversation_id_val: + with Session(db.engine) as session: + for f in files: + upload_id = getattr(f, "related_id", None) + if not upload_id: + continue + + # Upsert-ish: skip if record already exists + exists_stmt = select(NodeFileUsage.id).where( + NodeFileUsage.conversation_id == conversation_id_val, + NodeFileUsage.node_id == self.node_id, + NodeFileUsage.upload_file_id == upload_id, + ) + if not session.scalar(exists_stmt): + session.add( + NodeFileUsage( + conversation_id=conversation_id_val, + node_id=self.node_id, + upload_file_id=upload_id, + message_id=None, + ) + ) + session.commit() + except Exception: + # ignore if failed to persist + pass + value["history_prompt_messages"] = history_prompt_messages result[parameter_name] = value diff --git a/api/migrations/versions/2025_07_07_1152-b2a0bfccd123_add_node_file_usage.py b/api/migrations/versions/2025_07_07_1152-b2a0bfccd123_add_node_file_usage.py new file mode 100644 index 0000000000..f47e8fd10c --- /dev/null +++ b/api/migrations/versions/2025_07_07_1152-b2a0bfccd123_add_node_file_usage.py @@ -0,0 +1,36 @@ +"""add node_file_usage table + +Revision ID: b2a0bfccd123 +Revises: 0ab65e1cc7fa +Create Date: 2025-07-07 11:52:00 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b2a0bfccd123' +down_revision = '0ab65e1cc7fa' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "node_file_usage", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("uuid_generate_v4()")), + sa.Column("conversation_id", postgresql.UUID(as_uuid=True), nullable=False, index=True), + sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("node_id", sa.String(length=64), nullable=False, index=True), + sa.Column("upload_file_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False), + sa.ForeignKeyConstraint(["conversation_id"], ["conversations.id"],), + sa.ForeignKeyConstraint(["message_id"], ["messages.id"],), + sa.ForeignKeyConstraint(["upload_file_id"], ["upload_files.id"],), + ) + + +def downgrade() -> None: + op.drop_table("node_file_usage") \ No newline at end of file diff --git a/api/models/__init__.py b/api/models/__init__.py index 83b50eb099..cc9faeeefb 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -48,6 +48,7 @@ from .model import ( MessageChain, MessageFeedback, MessageFile, + NodeFileUsage, OperationLog, RecommendedApp, Site, @@ -138,6 +139,7 @@ __all__ = [ "MessageChain", "MessageFeedback", "MessageFile", + "NodeFileUsage", "OperationLog", "PinnedConversation", "Provider", diff --git a/api/models/model.py b/api/models/model.py index 93737043d5..9dbe6e8726 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1747,6 +1747,27 @@ class MessageAgentThought(Base): return {} +class NodeFileUsage(Base): + __tablename__ = "node_file_usage" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="node_file_usage_pkey"), + db.Index("node_file_usage_conversation_id_idx", "conversation_id"), + db.Index("node_file_usage_node_id_idx", "node_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) + message_id = db.Column(StringUUID, db.ForeignKey("messages.id"), nullable=True) + node_id = db.Column(db.String(64), nullable=False) + upload_file_id = db.Column(StringUUID, db.ForeignKey("upload_files.id"), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + def __repr__(self) -> str: # pragma: no cover + return "".format( + cid=self.conversation_id, nid=self.node_id, fid=self.upload_file_id + ) + + class DatasetRetrieverResource(Base): __tablename__ = "dataset_retriever_resources" __table_args__ = (