refactor: Refactors Knowledge Retrieval Node for enhanced modularity

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 4f7f37f398
commit 2ef31de170
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,10 +1,10 @@
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing import Literal, Optional
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel):
@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel):
weights: Optional[WeightedScoreConfig] = None
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
class SingleRetrievalConfig(BaseModel):
"""
Single Retrieval Config.
@ -129,7 +118,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_model_config: ModelConfig
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)

@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
@ -15,20 +15,30 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
)
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
@ -38,7 +48,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -46,7 +57,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData, ModelConfig
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
@ -56,6 +67,10 @@ from .exc import (
ModelQuotaExceededError,
)
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__)
default_retrieval_model = {
@ -67,10 +82,51 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(LLMNode):
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = KnowledgeRetrievalNodeData(**data)
@classmethod
def version(cls):
return "1"
@ -448,7 +504,7 @@ class KnowledgeRetrievalNode(LLMNode):
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
@ -458,16 +514,22 @@ class KnowledgeRetrievalNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.metadata_model_config, # type: ignore
generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)
for event in generator:

@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

@ -191,7 +191,10 @@ class LLMNode(BaseNode):
node_inputs["#context#"] = context
# fetch model config
model_instance, model_config = self._fetch_model_config(self.node_data.model)
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory
memory = llm_utils.fetch_memory(
@ -209,7 +212,7 @@ class LLMNode(BaseNode):
):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
@ -221,14 +224,20 @@ class LLMNode(BaseNode):
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
)
# handle invoke result
generator = self._invoke_llm(
generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)
structured_output: LLMStructuredOutput | None = None
@ -298,12 +307,18 @@ class LLMNode(BaseNode):
)
)
def _invoke_llm(
self,
@staticmethod
def invoke_llm(
*,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
user_id: str,
structured_output_enabled: bool,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@ -311,8 +326,8 @@ class LLMNode(BaseNode):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
if self.node_data.structured_output_enabled:
output_schema = self._fetch_structured_output_schema()
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema()
invoke_result = invoke_llm_with_structured_output(
provider=model_instance.provider,
model_schema=model_schema,
@ -322,7 +337,7 @@ class LLMNode(BaseNode):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=self.user_id,
user=user_id,
)
else:
invoke_result = model_instance.invoke_llm(
@ -330,17 +345,31 @@ class LLMNode(BaseNode):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=self.user_id,
user=user_id,
)
return self._handle_invoke_result(invoke_result=invoke_result)
return LLMNode.handle_invoke_result(
invoke_result=invoke_result,
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
)
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
@staticmethod
def handle_invoke_result(
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
event = self._handle_blocking_result(invoke_result=invoke_result)
event = LLMNode.handle_blocking_result(
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
)
yield event
return
@ -358,11 +387,13 @@ class LLMNode(BaseNode):
yield result
if isinstance(result, LLMResultChunk):
contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=contents,
file_saver=file_saver,
file_outputs=file_outputs,
):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(
chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
# Update the whole metadata
if not model and result.model:
@ -380,7 +411,8 @@ class LLMNode(BaseNode):
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
def _image_file_to_markdown(self, file: "File", /):
@staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@ -541,11 +573,14 @@ class LLMNode(BaseNode):
return None
@staticmethod
def _fetch_model_config(
self, node_data_model: ModelConfig
*,
node_data_model: ModelConfig,
tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
tenant_id=tenant_id, node_data_model=node_data_model
)
completion_params = model_config_with_cred.parameters
@ -558,8 +593,8 @@ class LLMNode(BaseNode):
node_data_model.completion_params = completion_params
return model, model_config_with_cred
def _fetch_prompt_messages(
self,
@staticmethod
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
@ -572,13 +607,14 @@ class LLMNode(BaseNode):
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
self._handle_list_messages(
LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@ -604,7 +640,7 @@ class LLMNode(BaseNode):
edition_type="basic",
)
prompt_messages.extend(
self._handle_list_messages(
LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@ -733,7 +769,7 @@ class LLMNode(BaseNode):
)
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model,
@ -837,8 +873,8 @@ class LLMNode(BaseNode):
},
}
def _handle_list_messages(
self,
@staticmethod
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
@ -899,9 +935,19 @@ class LLMNode(BaseNode):
return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
@staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=invoke_result.message.content,
file_saver=saver,
file_outputs=file_outputs,
):
buffer.write(text_part)
return ModelInvokeCompletedEvent(
@ -910,7 +956,12 @@ class LLMNode(BaseNode):
finish_reason=None,
)
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
@staticmethod
def save_multimodal_image_output(
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@ -920,19 +971,14 @@ class LLMNode(BaseNode):
Currently, only image files are supported.
"""
# Inject the saver somehow...
_saver = self._llm_file_saver
# If this
if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = _saver.save_binary_string(
saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
self._file_outputs.append(saved_file)
return saved_file
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
@ -950,16 +996,20 @@ class LLMNode(BaseNode):
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
def _fetch_structured_output_schema(self) -> dict[str, Any]:
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
"""
if not self.node_data.structured_output:
if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
@ -971,9 +1021,12 @@ class LLMNode(BaseNode):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
self,
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
@ -996,9 +1049,12 @@ class LLMNode(BaseNode):
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item)
self._file_outputs.append(file)
yield self._image_file_to_markdown(file)
file = LLMNode.save_multimodal_image_output(
content=item,
file_saver=file_saver,
)
file_outputs.append(file)
yield LLMNode._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -12,6 +12,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
@ -20,6 +21,7 @@ from core.workflow.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -35,11 +37,53 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData # type: ignore
class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER
node_data: QuestionClassifierNodeData
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = QuestionClassifierNodeData(**data)
@classmethod
def version(cls):
return "1"
@ -53,7 +97,10 @@ class QuestionClassifierNode(LLMNode):
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@ -91,7 +138,7 @@ class QuestionClassifierNode(LLMNode):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@ -101,6 +148,7 @@ class QuestionClassifierNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
@ -109,11 +157,16 @@ class QuestionClassifierNode(LLMNode):
try:
# handle invoke result
generator = self._invoke_llm(
generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=False,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)
for event in generator:

@ -540,7 +540,10 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
@ -566,7 +569,10 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)

Loading…
Cancel
Save