feat: Add node file usage tracking and filtering

Introduces the NodeFileUsage model and migration to track file usage per node in conversations. Updates token buffer memory and agent node logic to filter file attachments based on node context, preventing cross-node file leakage. Also registers the new model in the models package.
pull/21938/head
Kalo Chin 11 months ago
parent 87530438c0
commit b448a2fccf

@ -16,7 +16,7 @@ from core.model_runtime.entities.message_entities import PromptMessageContentUni
from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile from models.model import AppMode, Conversation, Message, MessageFile, NodeFileUsage
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
@ -26,7 +26,11 @@ class TokenBufferMemory:
self.model_instance = model_instance self.model_instance = model_instance
def get_history_prompt_messages( def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None self,
max_token_limit: int = 2000,
message_limit: Optional[int] = None,
*,
allowed_node_id: str | None = None,
) -> Sequence[PromptMessage]: ) -> Sequence[PromptMessage]:
""" """
Get history prompt messages. Get history prompt messages.
@ -70,8 +74,45 @@ class TokenBufferMemory:
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
# Attachment filtering strategy
# • allowed_node_id is None → keep files used by ANY node (shared memory)
# • allowed_node_id == "" → strip ALL attachments
# • allowed_node_id == <node id> → keep only files this node used
allowed_upload_ids: set[str] | None = None
if allowed_node_id == "":
allowed_upload_ids = set() # keep none → strip all attachments
elif allowed_node_id is None:
# Shared memory: allow files used by any node
usage_rows = (
db.session.query(NodeFileUsage.upload_file_id)
.filter(NodeFileUsage.conversation_id == self.conversation.id)
.all()
)
allowed_upload_ids = {str(r[0]) for r in usage_rows}
else:
# Node-specific filtering
usage_rows = (
db.session.query(NodeFileUsage.upload_file_id)
.filter(
NodeFileUsage.conversation_id == self.conversation.id,
NodeFileUsage.node_id == allowed_node_id,
)
.all()
)
allowed_upload_ids = {str(r[0]) for r in usage_rows}
for message in messages: for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
# If attachment filtering is enabled, filter MessageFile list first.
if allowed_upload_ids is not None:
files = [
f for f in files if f.upload_file_id and str(f.upload_file_id) in allowed_upload_ids
]
if files: if files:
file_extra_config = None file_extra_config = None
if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:

@ -34,8 +34,7 @@ from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from factories.agent_factory import get_plugin_agent_strategy from factories.agent_factory import get_plugin_agent_strategy
from models.model import Conversation from models.model import Conversation, NodeFileUsage
class AgentNode(ToolNode): class AgentNode(ToolNode):
""" """
@ -257,35 +256,6 @@ class AgentNode(ToolNode):
value = cast(dict[str, Any], value) value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value) model_instance, model_schema = self._fetch_model(value)
history_prompt_messages: list[dict[str, Any]] = [] history_prompt_messages: list[dict[str, Any]] = []
if node_data.memory:
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
# Strip file attachments from memory to prevent cross-turn leakage
def _strip_files_from_history(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
stripped: list[dict[str, Any]] = []
for m in messages:
if m.get("role") == "user" and isinstance(m.get("content"), list):
# Keep only text content, remove file attachments
contents = [
c
for c in m["content"]
if not (isinstance(c, dict) and c.get("type") and c.get("type") != "text")
]
if not contents:
contents = []
m = {**m, "content": contents}
stripped.append(m)
return stripped
history_prompt_messages = _strip_files_from_history(history_prompt_messages)
# Check if this agent node references sys.files # Check if this agent node references sys.files
def _input_uses_sys_files(_agent_input): def _input_uses_sys_files(_agent_input):
if _agent_input.type == "variable": if _agent_input.type == "variable":
@ -298,9 +268,19 @@ class AgentNode(ToolNode):
_input_uses_sys_files(inp) for inp in node_data.agent_parameters.values() _input_uses_sys_files(inp) for inp in node_data.agent_parameters.values()
) )
if uses_sys_files_for_node: conv_var = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
# History already stripped once above; ensure it's stripped (idempotent call) conversation_id_val = conv_var.value if isinstance(conv_var, StringSegment) else None
history_prompt_messages = _strip_files_from_history(history_prompt_messages)
if node_data.memory:
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size if node_data.memory.window.size else None,
allowed_node_id=None if uses_sys_files_for_node else "",
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
if model_schema: if model_schema:
# remove structured output feature to support old version agent plugin # remove structured output feature to support old version agent plugin
@ -340,9 +320,41 @@ class AgentNode(ToolNode):
continue continue
# Append synthetic user message with current turn attachments # Append synthetic user message with current turn attachments
synthetic_user_prompt = UserPromptMessage(content=prompt_contents) synthetic_user_prompt = UserPromptMessage(
content=prompt_contents,
name=f"__SYS_FILES__|{self.node_id}", # mark with node id for future filtering
)
history_prompt_messages.append(synthetic_user_prompt.model_dump(mode="json")) history_prompt_messages.append(synthetic_user_prompt.model_dump(mode="json"))
# Persist usage to durable table
try:
if conversation_id_val:
with Session(db.engine) as session:
for f in files:
upload_id = getattr(f, "related_id", None)
if not upload_id:
continue
# Upsert-ish: skip if record already exists
exists_stmt = select(NodeFileUsage.id).where(
NodeFileUsage.conversation_id == conversation_id_val,
NodeFileUsage.node_id == self.node_id,
NodeFileUsage.upload_file_id == upload_id,
)
if not session.scalar(exists_stmt):
session.add(
NodeFileUsage(
conversation_id=conversation_id_val,
node_id=self.node_id,
upload_file_id=upload_id,
message_id=None,
)
)
session.commit()
except Exception:
# ignore if failed to persist
pass
value["history_prompt_messages"] = history_prompt_messages value["history_prompt_messages"] = history_prompt_messages
result[parameter_name] = value result[parameter_name] = value

@ -0,0 +1,36 @@
"""add node_file_usage table
Revision ID: b2a0bfccd123
Revises: 0ab65e1cc7fa
Create Date: 2025-07-07 11:52:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'b2a0bfccd123'
down_revision = '0ab65e1cc7fa'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"node_file_usage",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("uuid_generate_v4()")),
sa.Column("conversation_id", postgresql.UUID(as_uuid=True), nullable=False, index=True),
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("node_id", sa.String(length=64), nullable=False, index=True),
sa.Column("upload_file_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(["conversation_id"], ["conversations.id"],),
sa.ForeignKeyConstraint(["message_id"], ["messages.id"],),
sa.ForeignKeyConstraint(["upload_file_id"], ["upload_files.id"],),
)
def downgrade() -> None:
op.drop_table("node_file_usage")

@ -48,6 +48,7 @@ from .model import (
MessageChain, MessageChain,
MessageFeedback, MessageFeedback,
MessageFile, MessageFile,
NodeFileUsage,
OperationLog, OperationLog,
RecommendedApp, RecommendedApp,
Site, Site,
@ -138,6 +139,7 @@ __all__ = [
"MessageChain", "MessageChain",
"MessageFeedback", "MessageFeedback",
"MessageFile", "MessageFile",
"NodeFileUsage",
"OperationLog", "OperationLog",
"PinnedConversation", "PinnedConversation",
"Provider", "Provider",

@ -1747,6 +1747,27 @@ class MessageAgentThought(Base):
return {} return {}
class NodeFileUsage(Base):
__tablename__ = "node_file_usage"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="node_file_usage_pkey"),
db.Index("node_file_usage_conversation_id_idx", "conversation_id"),
db.Index("node_file_usage_node_id_idx", "node_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
message_id = db.Column(StringUUID, db.ForeignKey("messages.id"), nullable=True)
node_id = db.Column(db.String(64), nullable=False)
upload_file_id = db.Column(StringUUID, db.ForeignKey("upload_files.id"), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self) -> str: # pragma: no cover
return "<NodeFileUsage conversation_id={cid} node_id={nid} upload_file_id={fid}>".format(
cid=self.conversation_id, nid=self.node_id, fid=self.upload_file_id
)
class DatasetRetrieverResource(Base): class DatasetRetrieverResource(Base):
__tablename__ = "dataset_retriever_resources" __tablename__ = "dataset_retriever_resources"
__table_args__ = ( __table_args__ = (

Loading…
Cancel
Save