From 2ef31de1703ee5adb3e405cfd3277bc0be00ffdf Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 15 Jul 2025 17:24:55 +0800 Subject: [PATCH] refactor: Refactors Knowledge Retrieval Node for enhanced modularity Signed-off-by: -LAN- --- .../nodes/knowledge_retrieval/entities.py | 17 +-- .../knowledge_retrieval_node.py | 86 +++++++++-- api/core/workflow/nodes/llm/entities.py | 4 +- api/core/workflow/nodes/llm/node.py | 144 ++++++++++++------ .../question_classifier_node.py | 65 +++++++- .../core/workflow/nodes/llm/test_node.py | 10 +- 6 files changed, 246 insertions(+), 80 deletions(-) diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 19bdee4fe2..e9122b1eec 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,10 +1,10 @@ from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Literal, Optional from pydantic import BaseModel, Field from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import VisionConfig +from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): @@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel): weights: Optional[WeightedScoreConfig] = None -class ModelConfig(BaseModel): - """ - Model Config. - """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} - - class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. @@ -129,7 +118,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None + metadata_model_config: ModelConfig metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index f05d93d83e..ef617cc878 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ import re import time from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast @@ -15,20 +15,30 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.entities.message_entities import ( + PromptMessageRole, +) +from core.model_runtime.entities.model_entities import ( + ModelFeature, + ModelType, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.variables import StringSegment +from core.variables import ( + StringSegment, +) from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import ModelInvokeCompletedEvent +from core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, +) from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_ASSISTANT_PROMPT_1, METADATA_FILTER_ASSISTANT_PROMPT_2, @@ -38,7 +48,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig +from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -46,7 +57,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog from services.feature_service import FeatureService -from .entities import KnowledgeRetrievalNodeData, ModelConfig +from .entities import KnowledgeRetrievalNodeData from .exc import ( InvalidModelTypeError, KnowledgeRetrievalNodeError, @@ -56,6 +67,10 @@ from .exc import ( ModelQuotaExceededError, ) +if TYPE_CHECKING: + from core.file.models import File + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + logger = logging.getLogger(__name__) default_retrieval_model = { @@ -67,10 +82,51 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(LLMNode): - _node_data_cls = KnowledgeRetrievalNodeData # type: ignore +class KnowledgeRetrievalNode(BaseNode): _node_type = NodeType.KNOWLEDGE_RETRIEVAL + node_data: KnowledgeRetrievalNodeData + + # Instance attributes specific to LLMNode. + # Output variable for file + _file_outputs: list["File"] + + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = KnowledgeRetrievalNodeData(**data) + @classmethod def version(cls): return "1" @@ -448,7 +504,7 @@ class KnowledgeRetrievalNode(LLMNode): metadata_fields=all_metadata_fields, query=query or "", ) - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query=query, memory=None, @@ -458,16 +514,22 @@ class KnowledgeRetrievalNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=self.graph_runtime_state.variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" try: # handle invoke result - generator = self._invoke_llm( - node_data_model=node_data.metadata_model_config, # type: ignore + generator = LLMNode.invoke_llm( + node_data_model=node_data.metadata_model_config, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self.node_data.structured_output_enabled, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 36d0688807..4bb62d35a2 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, Optional from pydantic import BaseModel, Field, field_validator @@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData): memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: dict | None = None + structured_output: Mapping[str, Any] | None = None # We used 'structured_output_enabled' in the past, but it's not a good name. structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index cd2d4a7970..f57b401a7f 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -191,7 +191,10 @@ class LLMNode(BaseNode): node_inputs["#context#"] = context # fetch model config - model_instance, model_config = self._fetch_model_config(self.node_data.model) + model_instance, model_config = LLMNode._fetch_model_config( + node_data_model=self.node_data.model, + tenant_id=self.tenant_id, + ) # fetch memory memory = llm_utils.fetch_memory( @@ -209,7 +212,7 @@ class LLMNode(BaseNode): ): query = query_variable.text - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, context=context, @@ -221,14 +224,20 @@ class LLMNode(BaseNode): vision_detail=self.node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, + tenant_id=self.tenant_id, ) # handle invoke result - generator = self._invoke_llm( + generator = LLMNode.invoke_llm( node_data_model=self.node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self.node_data.structured_output_enabled, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) structured_output: LLMStructuredOutput | None = None @@ -298,12 +307,18 @@ class LLMNode(BaseNode): ) ) - def _invoke_llm( - self, + @staticmethod + def invoke_llm( + *, node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, + user_id: str, + structured_output_enabled: bool, + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: model_schema = model_instance.model_type_instance.get_model_schema( node_data_model.name, model_instance.credentials @@ -311,8 +326,8 @@ class LLMNode(BaseNode): if not model_schema: raise ValueError(f"Model schema not found for {node_data_model.name}") - if self.node_data.structured_output_enabled: - output_schema = self._fetch_structured_output_schema() + if structured_output_enabled: + output_schema = LLMNode.fetch_structured_output_schema() invoke_result = invoke_llm_with_structured_output( provider=model_instance.provider, model_schema=model_schema, @@ -322,7 +337,7 @@ class LLMNode(BaseNode): model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, - user=self.user_id, + user=user_id, ) else: invoke_result = model_instance.invoke_llm( @@ -330,17 +345,31 @@ class LLMNode(BaseNode): model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, - user=self.user_id, + user=user_id, ) - return self._handle_invoke_result(invoke_result=invoke_result) + return LLMNode.handle_invoke_result( + invoke_result=invoke_result, + file_saver=file_saver, + file_outputs=file_outputs, + node_id=node_id, + ) - def _handle_invoke_result( - self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] + @staticmethod + def handle_invoke_result( + *, + invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: # For blocking mode if isinstance(invoke_result, LLMResult): - event = self._handle_blocking_result(invoke_result=invoke_result) + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=file_saver, + file_outputs=file_outputs, + ) yield event return @@ -358,11 +387,13 @@ class LLMNode(BaseNode): yield result if isinstance(result, LLMResultChunk): contents = result.delta.message.content - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=contents, + file_saver=file_saver, + file_outputs=file_outputs, + ): full_text_buffer.write(text_part) - yield RunStreamChunkEvent( - chunk_content=text_part, from_variable_selector=[self.node_id, "text"] - ) + yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"]) # Update the whole metadata if not model and result.model: @@ -380,7 +411,8 @@ class LLMNode(BaseNode): yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) - def _image_file_to_markdown(self, file: "File", /): + @staticmethod + def _image_file_to_markdown(file: "File", /): text_chunk = f"![]({file.generate_url()})" return text_chunk @@ -541,11 +573,14 @@ class LLMNode(BaseNode): return None + @staticmethod def _fetch_model_config( - self, node_data_model: ModelConfig + *, + node_data_model: ModelConfig, + tenant_id: str, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: model, model_config_with_cred = llm_utils.fetch_model_config( - tenant_id=self.tenant_id, node_data_model=node_data_model + tenant_id=tenant_id, node_data_model=node_data_model ) completion_params = model_config_with_cred.parameters @@ -558,8 +593,8 @@ class LLMNode(BaseNode): node_data_model.completion_params = completion_params return model, model_config_with_cred - def _fetch_prompt_messages( - self, + @staticmethod + def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence["File"], @@ -572,13 +607,14 @@ class LLMNode(BaseNode): vision_detail: ImagePromptMessageContent.DETAIL, variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], + tenant_id: str, ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): # For chat model prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, @@ -604,7 +640,7 @@ class LLMNode(BaseNode): edition_type="basic", ) prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=[message], context="", jinja2_variables=[], @@ -733,7 +769,7 @@ class LLMNode(BaseNode): ) model = ModelManager().get_model_instance( - tenant_id=self.tenant_id, + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model, @@ -837,8 +873,8 @@ class LLMNode(BaseNode): }, } - def _handle_list_messages( - self, + @staticmethod + def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str], @@ -899,9 +935,19 @@ class LLMNode(BaseNode): return prompt_messages - def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: + @staticmethod + def handle_blocking_result( + *, + invoke_result: LLMResult, + saver: LLMFileSaver, + file_outputs: list["File"], + ) -> ModelInvokeCompletedEvent: buffer = io.StringIO() - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=invoke_result.message.content, + file_saver=saver, + file_outputs=file_outputs, + ): buffer.write(text_part) return ModelInvokeCompletedEvent( @@ -910,7 +956,12 @@ class LLMNode(BaseNode): finish_reason=None, ) - def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": + @staticmethod + def save_multimodal_image_output( + *, + content: ImagePromptMessageContent, + file_saver: LLMFileSaver, + ) -> "File": """_save_multimodal_output saves multi-modal contents generated by LLM plugins. There are two kinds of multimodal outputs: @@ -920,19 +971,14 @@ class LLMNode(BaseNode): Currently, only image files are supported. """ - # Inject the saver somehow... - _saver = self._llm_file_saver - - # If this if content.url != "": - saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) + saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) else: - saved_file = _saver.save_binary_string( + saved_file = file_saver.save_binary_string( data=base64.b64decode(content.base64_data), mime_type=content.mime_type, file_type=FileType.IMAGE, ) - self._file_outputs.append(saved_file) return saved_file def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: @@ -950,16 +996,20 @@ class LLMNode(BaseNode): model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_schema - def _fetch_structured_output_schema(self) -> dict[str, Any]: + @staticmethod + def fetch_structured_output_schema( + *, + structured_output: Mapping[str, Any] | None = None, + ) -> dict[str, Any]: """ Fetch the structured output schema from the node data. Returns: dict[str, Any]: The structured output schema """ - if not self.node_data.structured_output: + if not structured_output: raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) + structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) if not structured_output_schema: raise LLMNodeError("Please provide a valid structured output schema") @@ -971,9 +1021,12 @@ class LLMNode(BaseNode): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") + @staticmethod def _save_multimodal_output_and_convert_result_to_markdown( - self, + *, contents: str | list[PromptMessageContentUnionTypes] | None, + file_saver: LLMFileSaver, + file_outputs: list["File"], ) -> Generator[str, None, None]: """Convert intermediate prompt messages into strings and yield them to the caller. @@ -996,9 +1049,12 @@ class LLMNode(BaseNode): if isinstance(item, TextPromptMessageContent): yield item.data elif isinstance(item, ImagePromptMessageContent): - file = self._save_multimodal_image_output(item) - self._file_outputs.append(file) - yield self._image_file_to_markdown(file) + file = LLMNode.save_multimodal_image_output( + content=item, + file_saver=file_saver, + ) + file_outputs.append(file) + yield LLMNode._image_file_to_markdown(file) else: logger.warning("unknown item type encountered, type=%s", type(item)) yield str(item) diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 74024ed90c..93ea7c9ca4 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -12,6 +12,7 @@ from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.llm import ( @@ -20,6 +21,7 @@ from core.workflow.nodes.llm import ( LLMNodeCompletionModelPromptTemplate, llm_utils, ) +from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.utils.variable_template_parser import VariableTemplateParser from libs.json_in_md_parser import parse_and_check_json_markdown @@ -35,11 +37,53 @@ from .template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_3, ) +if TYPE_CHECKING: + from core.file.models import File + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState -class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData # type: ignore + +class QuestionClassifierNode(BaseNode): _node_type = NodeType.QUESTION_CLASSIFIER + node_data: QuestionClassifierNodeData + + _file_outputs: list["File"] + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = QuestionClassifierNodeData(**data) + @classmethod def version(cls): return "1" @@ -53,7 +97,10 @@ class QuestionClassifierNode(LLMNode): query = variable.value if variable else None variables = {"query": query} # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance, model_config = LLMNode._fetch_model_config( + node_data_model=node_data.model, + tenant_id=self.tenant_id, + ) # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -91,7 +138,7 @@ class QuestionClassifierNode(LLMNode): # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, # two consecutive user prompts will be generated, causing model's error. # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query="", memory=memory, @@ -101,6 +148,7 @@ class QuestionClassifierNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" @@ -109,11 +157,16 @@ class QuestionClassifierNode(LLMNode): try: # handle invoke result - generator = self._invoke_llm( + generator = LLMNode.invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=False, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fefad0ec95..42f3bffab1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -540,7 +540,10 @@ class TestLLMNodeSaveMultiModalImageOutput: size=9, ) mock_file_saver.save_binary_string.return_value = mock_file - file = llm_node._save_multimodal_image_output(content=content) + file = llm_node.save_multimodal_image_output( + content=content, + file_saver=mock_file_saver, + ) assert llm_node._file_outputs == [mock_file] assert file == mock_file mock_file_saver.save_binary_string.assert_called_once_with( @@ -566,7 +569,10 @@ class TestLLMNodeSaveMultiModalImageOutput: size=9, ) mock_file_saver.save_remote_url.return_value = mock_file - file = llm_node._save_multimodal_image_output(content=content) + file = llm_node.save_multimodal_image_output( + content=content, + file_saver=mock_file_saver, + ) assert llm_node._file_outputs == [mock_file] assert file == mock_file mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)