diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 9bfb402dc8..756cbf2cee 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -69,6 +69,12 @@ from core.workflow.nodes.event import ( ) from core.workflow.utils.variable_template_parser import VariableTemplateParser +# --- New imports for file-usage persistence --- +from sqlalchemy.orm import Session +from sqlalchemy import select +from extensions.ext_database import db +from models.model import NodeFileUsage + from . import llm_utils from .entities import ( LLMNodeChatModelMessage, @@ -175,6 +181,46 @@ class LLMNode(BaseNode[LLMNodeData]): else [] ) + # ------------------------------ + # Persist current-turn file usage so that subsequent turns/nodes can see them. + # A node is considered "using" sys.files when vision is enabled and a variable selector is configured. + # ------------------------------ + + uses_sys_files_for_node = self.node_data.vision.enabled and bool( + self.node_data.vision.configs.variable_selector + ) + + if uses_sys_files_for_node and files: + conv_var = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value]) + conversation_id_val = conv_var.value if isinstance(conv_var, StringSegment) else None + + if conversation_id_val: + try: + with Session(db.engine) as session: + for f in files: + upload_id = getattr(f, "related_id", None) + if not upload_id: + continue + + exists_stmt = select(NodeFileUsage.id).where( + NodeFileUsage.conversation_id == conversation_id_val, + NodeFileUsage.node_id == self.node_id, + NodeFileUsage.upload_file_id == upload_id, + ) + if not session.scalar(exists_stmt): + session.add( + NodeFileUsage( + conversation_id=conversation_id_val, + node_id=self.node_id, + upload_file_id=upload_id, + message_id=None, + ) + ) + session.commit() + except Exception: + # Non-critical – failure to persist should not break node execution + pass + if files: node_inputs["#files#"] = [file.to_dict() for file in files] @@ -219,6 +265,7 @@ class LLMNode(BaseNode[LLMNodeData]): vision_detail=self.node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, + uses_sys_files_for_node=uses_sys_files_for_node, ) # handle invoke result @@ -570,6 +617,7 @@ class LLMNode(BaseNode[LLMNodeData]): vision_detail: ImagePromptMessageContent.DETAIL, variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], + uses_sys_files_for_node: bool, ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: prompt_messages: list[PromptMessage] = [] @@ -590,6 +638,7 @@ class LLMNode(BaseNode[LLMNodeData]): memory=memory, memory_config=memory_config, model_config=model_config, + allowed_node_id=None if uses_sys_files_for_node else "", ) # Extend prompt_messages with memory messages prompt_messages.extend(memory_messages) @@ -1075,6 +1124,7 @@ def _handle_memory_chat_mode( memory: TokenBufferMemory | None, memory_config: MemoryConfig | None, model_config: ModelConfigWithCredentialsEntity, + allowed_node_id: str | None, ) -> Sequence[PromptMessage]: memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model @@ -1083,6 +1133,7 @@ def _handle_memory_chat_mode( memory_messages = memory.get_history_prompt_messages( max_token_limit=rest_tokens, message_limit=memory_config.window.size if memory_config.window.enabled else None, + allowed_node_id=allowed_node_id, ) return memory_messages