refactor: restructure Knowledge Retrieval Node for enhanced modularity

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 4f7f37f398
commit 2ef31de170
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,10 +1,10 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Literal, Optional from typing import Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel): class RerankingModelConfig(BaseModel):
@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel):
weights: Optional[WeightedScoreConfig] = None weights: Optional[WeightedScoreConfig] = None
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
class SingleRetrievalConfig(BaseModel): class SingleRetrievalConfig(BaseModel):
""" """
Single Retrieval Config. Single Retrieval Config.
@ -129,7 +118,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None metadata_model_config: ModelConfig
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)

@ -4,7 +4,7 @@ import re
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
@ -15,20 +15,30 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities.model_entities import ModelFeature, ModelType PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
)
from core.workflow.nodes.knowledge_retrieval.template_prompts import ( from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1, METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2, METADATA_FILTER_ASSISTANT_PROMPT_2,
@ -38,7 +48,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3, METADATA_FILTER_USER_PROMPT_3,
) )
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -46,7 +57,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData, ModelConfig from .entities import KnowledgeRetrievalNodeData
from .exc import ( from .exc import (
InvalidModelTypeError, InvalidModelTypeError,
KnowledgeRetrievalNodeError, KnowledgeRetrievalNodeError,
@ -56,6 +67,10 @@ from .exc import (
ModelQuotaExceededError, ModelQuotaExceededError,
) )
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
default_retrieval_model = { default_retrieval_model = {
@ -67,10 +82,51 @@ default_retrieval_model = {
} }
class KnowledgeRetrievalNode(LLMNode): class KnowledgeRetrievalNode(BaseNode):
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_type = NodeType.KNOWLEDGE_RETRIEVAL
node_data: KnowledgeRetrievalNodeData
# Instance attributes handed to the LLMNode static helpers.
# Collects the files saved from LLM output.
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
    self,
    id: str,
    config: Mapping[str, Any],
    graph_init_params: "GraphInitParams",
    graph: "Graph",
    graph_runtime_state: "GraphRuntimeState",
    previous_node_id: Optional[str] = None,
    thread_pool_id: Optional[str] = None,
    *,
    llm_file_saver: LLMFileSaver | None = None,
) -> None:
    """Initialise the node and the state it passes to the LLM helpers.

    All positional parameters are forwarded unchanged to the base class
    initialiser.

    Args:
        llm_file_saver: Saver used for files produced by the LLM. When
            omitted, a ``FileSaverImpl`` is built from the ``user_id`` and
            ``tenant_id`` of ``graph_init_params``.
    """
    super().__init__(
        id=id,
        config=config,
        graph_init_params=graph_init_params,
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        previous_node_id=previous_node_id,
        thread_pool_id=thread_pool_id,
    )
    # Files produced by the LLM (multimodal outputs) are accumulated here.
    self._file_outputs: list["File"] = []

    # Fall back to the default saver only when the caller injected none.
    self._llm_file_saver = (
        llm_file_saver
        if llm_file_saver is not None
        else FileSaverImpl(
            user_id=graph_init_params.user_id,
            tenant_id=graph_init_params.tenant_id,
        )
    )
def from_dict(self, data: Mapping[str, Any]) -> None:
    """Build ``KnowledgeRetrievalNodeData`` from ``data`` and store it as ``node_data``."""
    parsed_node_data = KnowledgeRetrievalNodeData(**data)
    self.node_data = parsed_node_data
@classmethod @classmethod
def version(cls): def version(cls):
return "1" return "1"
@ -448,7 +504,7 @@ class KnowledgeRetrievalNode(LLMNode):
metadata_fields=all_metadata_fields, metadata_fields=all_metadata_fields,
query=query or "", query=query or "",
) )
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
sys_query=query, sys_query=query,
memory=None, memory=None,
@ -458,16 +514,22 @@ class KnowledgeRetrievalNode(LLMNode):
vision_detail=node_data.vision.configs.detail, vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[], jinja2_variables=[],
tenant_id=self.tenant_id,
) )
result_text = "" result_text = ""
try: try:
# handle invoke result # handle invoke result
generator = self._invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config, # type: ignore node_data_model=node_data.metadata_model_config,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
) )
for event in generator: for event in generator:

@ -1,4 +1,4 @@
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None memory: Optional[MemoryConfig] = None
context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name. # We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

@ -191,7 +191,10 @@ class LLMNode(BaseNode):
node_inputs["#context#"] = context node_inputs["#context#"] = context
# fetch model config # fetch model config
model_instance, model_config = self._fetch_model_config(self.node_data.model) model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory # fetch memory
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
@ -209,7 +212,7 @@ class LLMNode(BaseNode):
): ):
query = query_variable.text query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query, sys_query=query,
sys_files=files, sys_files=files,
context=context, context=context,
@ -221,14 +224,20 @@ class LLMNode(BaseNode):
vision_detail=self.node_data.vision.configs.detail, vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool, variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables, jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
) )
# handle invoke result # handle invoke result
generator = self._invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model, node_data_model=self.node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
) )
structured_output: LLMStructuredOutput | None = None structured_output: LLMStructuredOutput | None = None
@ -298,12 +307,18 @@ class LLMNode(BaseNode):
) )
) )
def _invoke_llm( @staticmethod
self, def invoke_llm(
*,
node_data_model: ModelConfig, node_data_model: ModelConfig,
model_instance: ModelInstance, model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage], prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
user_id: str,
structured_output_enabled: bool,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema( model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials node_data_model.name, model_instance.credentials
@ -311,8 +326,8 @@ class LLMNode(BaseNode):
if not model_schema: if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}") raise ValueError(f"Model schema not found for {node_data_model.name}")
if self.node_data.structured_output_enabled: if structured_output_enabled:
output_schema = self._fetch_structured_output_schema() output_schema = LLMNode.fetch_structured_output_schema()
invoke_result = invoke_llm_with_structured_output( invoke_result = invoke_llm_with_structured_output(
provider=model_instance.provider, provider=model_instance.provider,
model_schema=model_schema, model_schema=model_schema,
@ -322,7 +337,7 @@ class LLMNode(BaseNode):
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,
stop=list(stop or []), stop=list(stop or []),
stream=True, stream=True,
user=self.user_id, user=user_id,
) )
else: else:
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
@ -330,17 +345,31 @@ class LLMNode(BaseNode):
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,
stop=list(stop or []), stop=list(stop or []),
stream=True, stream=True,
user=self.user_id, user=user_id,
) )
return self._handle_invoke_result(invoke_result=invoke_result) return LLMNode.handle_invoke_result(
invoke_result=invoke_result,
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
)
def _handle_invoke_result( @staticmethod
self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] def handle_invoke_result(
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode # For blocking mode
if isinstance(invoke_result, LLMResult): if isinstance(invoke_result, LLMResult):
event = self._handle_blocking_result(invoke_result=invoke_result) event = LLMNode.handle_blocking_result(
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
)
yield event yield event
return return
@ -358,11 +387,13 @@ class LLMNode(BaseNode):
yield result yield result
if isinstance(result, LLMResultChunk): if isinstance(result, LLMResultChunk):
contents = result.delta.message.content contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=contents,
file_saver=file_saver,
file_outputs=file_outputs,
):
full_text_buffer.write(text_part) full_text_buffer.write(text_part)
yield RunStreamChunkEvent( yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
)
# Update the whole metadata # Update the whole metadata
if not model and result.model: if not model and result.model:
@ -380,7 +411,8 @@ class LLMNode(BaseNode):
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
def _image_file_to_markdown(self, file: "File", /): @staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})" text_chunk = f"![]({file.generate_url()})"
return text_chunk return text_chunk
@ -541,11 +573,14 @@ class LLMNode(BaseNode):
return None return None
@staticmethod
def _fetch_model_config( def _fetch_model_config(
self, node_data_model: ModelConfig *,
node_data_model: ModelConfig,
tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config( model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model tenant_id=tenant_id, node_data_model=node_data_model
) )
completion_params = model_config_with_cred.parameters completion_params = model_config_with_cred.parameters
@ -558,8 +593,8 @@ class LLMNode(BaseNode):
node_data_model.completion_params = completion_params node_data_model.completion_params = completion_params
return model, model_config_with_cred return model, model_config_with_cred
def _fetch_prompt_messages( @staticmethod
self, def fetch_prompt_messages(
*, *,
sys_query: str | None = None, sys_query: str | None = None,
sys_files: Sequence["File"], sys_files: Sequence["File"],
@ -572,13 +607,14 @@ class LLMNode(BaseNode):
vision_detail: ImagePromptMessageContent.DETAIL, vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool, variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector], jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list): if isinstance(prompt_template, list):
# For chat model # For chat model
prompt_messages.extend( prompt_messages.extend(
self._handle_list_messages( LLMNode.handle_list_messages(
messages=prompt_template, messages=prompt_template,
context=context, context=context,
jinja2_variables=jinja2_variables, jinja2_variables=jinja2_variables,
@ -604,7 +640,7 @@ class LLMNode(BaseNode):
edition_type="basic", edition_type="basic",
) )
prompt_messages.extend( prompt_messages.extend(
self._handle_list_messages( LLMNode.handle_list_messages(
messages=[message], messages=[message],
context="", context="",
jinja2_variables=[], jinja2_variables=[],
@ -733,7 +769,7 @@ class LLMNode(BaseNode):
) )
model = ModelManager().get_model_instance( model = ModelManager().get_model_instance(
tenant_id=self.tenant_id, tenant_id=tenant_id,
model_type=ModelType.LLM, model_type=ModelType.LLM,
provider=model_config.provider, provider=model_config.provider,
model=model_config.model, model=model_config.model,
@ -837,8 +873,8 @@ class LLMNode(BaseNode):
}, },
} }
def _handle_list_messages( @staticmethod
self, def handle_list_messages(
*, *,
messages: Sequence[LLMNodeChatModelMessage], messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str], context: Optional[str],
@ -899,9 +935,19 @@ class LLMNode(BaseNode):
return prompt_messages return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: @staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO() buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=invoke_result.message.content,
file_saver=saver,
file_outputs=file_outputs,
):
buffer.write(text_part) buffer.write(text_part)
return ModelInvokeCompletedEvent( return ModelInvokeCompletedEvent(
@ -910,7 +956,12 @@ class LLMNode(BaseNode):
finish_reason=None, finish_reason=None,
) )
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": @staticmethod
def save_multimodal_image_output(
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins. """_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs: There are two kinds of multimodal outputs:
@ -920,19 +971,14 @@ class LLMNode(BaseNode):
Currently, only image files are supported. Currently, only image files are supported.
""" """
# Inject the saver somehow...
_saver = self._llm_file_saver
# If this
if content.url != "": if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else: else:
saved_file = _saver.save_binary_string( saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data), data=base64.b64decode(content.base64_data),
mime_type=content.mime_type, mime_type=content.mime_type,
file_type=FileType.IMAGE, file_type=FileType.IMAGE,
) )
self._file_outputs.append(saved_file)
return saved_file return saved_file
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
@ -950,16 +996,20 @@ class LLMNode(BaseNode):
model_schema = model_type_instance.get_model_schema(model_name, model_credentials) model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema return model_schema
def _fetch_structured_output_schema(self) -> dict[str, Any]: @staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
""" """
Fetch the structured output schema from the node data. Fetch the structured output schema from the node data.
Returns: Returns:
dict[str, Any]: The structured output schema dict[str, Any]: The structured output schema
""" """
if not self.node_data.structured_output: if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema") raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema: if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema") raise LLMNodeError("Please provide a valid structured output schema")
@ -971,9 +1021,12 @@ class LLMNode(BaseNode):
except json.JSONDecodeError: except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format") raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown( def _save_multimodal_output_and_convert_result_to_markdown(
self, *,
contents: str | list[PromptMessageContentUnionTypes] | None, contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller. """Convert intermediate prompt messages into strings and yield them to the caller.
@ -996,9 +1049,12 @@ class LLMNode(BaseNode):
if isinstance(item, TextPromptMessageContent): if isinstance(item, TextPromptMessageContent):
yield item.data yield item.data
elif isinstance(item, ImagePromptMessageContent): elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item) file = LLMNode.save_multimodal_image_output(
self._file_outputs.append(file) content=item,
yield self._image_file_to_markdown(file) file_saver=file_saver,
)
file_outputs.append(file)
yield LLMNode._image_file_to_markdown(file)
else: else:
logger.warning("unknown item type encountered, type=%s", type(item)) logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item) yield str(item)

@ -1,6 +1,6 @@
import json import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
@ -12,6 +12,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import ( from core.workflow.nodes.llm import (
@ -20,6 +21,7 @@ from core.workflow.nodes.llm import (
LLMNodeCompletionModelPromptTemplate, LLMNodeCompletionModelPromptTemplate,
llm_utils, llm_utils,
) )
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown from libs.json_in_md_parser import parse_and_check_json_markdown
@ -35,11 +37,53 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3, QUESTION_CLASSIFIER_USER_PROMPT_3,
) )
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData # type: ignore class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER _node_type = NodeType.QUESTION_CLASSIFIER
node_data: QuestionClassifierNodeData
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
    self,
    id: str,
    config: Mapping[str, Any],
    graph_init_params: "GraphInitParams",
    graph: "Graph",
    graph_runtime_state: "GraphRuntimeState",
    previous_node_id: Optional[str] = None,
    thread_pool_id: Optional[str] = None,
    *,
    llm_file_saver: LLMFileSaver | None = None,
) -> None:
    """Initialise the classifier node and the state it passes to the LLM helpers.

    All positional parameters are forwarded unchanged to the base class
    initialiser.

    Args:
        llm_file_saver: Saver used for files produced by the LLM. When
            omitted, a ``FileSaverImpl`` is built from the ``user_id`` and
            ``tenant_id`` of ``graph_init_params``.
    """
    super().__init__(
        id=id,
        config=config,
        graph_init_params=graph_init_params,
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        previous_node_id=previous_node_id,
        thread_pool_id=thread_pool_id,
    )
    # Files produced by the LLM (multimodal outputs) are accumulated here.
    self._file_outputs: list["File"] = []

    # Use the injected saver when given; otherwise build the default one.
    resolved_saver = llm_file_saver
    if resolved_saver is None:
        resolved_saver = FileSaverImpl(
            user_id=graph_init_params.user_id,
            tenant_id=graph_init_params.tenant_id,
        )
    self._llm_file_saver = resolved_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
    """Build ``QuestionClassifierNodeData`` from ``data`` and store it as ``node_data``."""
    parsed_node_data = QuestionClassifierNodeData(**data)
    self.node_data = parsed_node_data
@classmethod @classmethod
def version(cls): def version(cls):
return "1" return "1"
@ -53,7 +97,10 @@ class QuestionClassifierNode(LLMNode):
query = variable.value if variable else None query = variable.value if variable else None
variables = {"query": query} variables = {"query": query}
# fetch model config # fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model) model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory # fetch memory
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
variable_pool=variable_pool, variable_pool=variable_pool,
@ -91,7 +138,7 @@ class QuestionClassifierNode(LLMNode):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error. # two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
sys_query="", sys_query="",
memory=memory, memory=memory,
@ -101,6 +148,7 @@ class QuestionClassifierNode(LLMNode):
vision_detail=node_data.vision.configs.detail, vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool, variable_pool=variable_pool,
jinja2_variables=[], jinja2_variables=[],
tenant_id=self.tenant_id,
) )
result_text = "" result_text = ""
@ -109,11 +157,16 @@ class QuestionClassifierNode(LLMNode):
try: try:
# handle invoke result # handle invoke result
generator = self._invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=node_data.model, node_data_model=node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id,
structured_output_enabled=False,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
) )
for event in generator: for event in generator:

@ -540,7 +540,10 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9, size=9,
) )
mock_file_saver.save_binary_string.return_value = mock_file mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with( mock_file_saver.save_binary_string.assert_called_once_with(
@ -566,7 +569,10 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9, size=9,
) )
mock_file_saver.save_remote_url.return_value = mock_file mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)

Loading…
Cancel
Save