@@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast
 
+import json_repair
+
 from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
@@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+)
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -57,6 +65,12 @@ from core.workflow.nodes.event import (
     RunRetrieverResourceEvent,
     RunStreamChunkEvent,
 )
+from core.workflow.utils.structured_output.entities import (
+    ResponseFormat,
+    SpecialModelType,
+    SupportStructuredOutputStatus,
+)
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_type = NodeType.LLM
 
     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+            """Process structured output if enabled"""
+            if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
+                return None
+            return self._parse_structured_output(text)
+
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
@@ -130,7 +150,6 @@
                 if isinstance(event, RunRetrieverResourceEvent):
                     context = event.context
                     yield event
-
             if context:
                 node_inputs["#context#"] = context
 
@@ -192,7 +211,9 @@
                     self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                     break
             outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
-
+            structured_output = process_structured_output(result_text)
+            if structured_output:
+                outputs["structured_output"] = structured_output
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
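Taken with the process_structured_output helper added above, the success path attaches a structured_output key only when parsing yields something. A minimal sketch of the resulting outputs mapping, with illustrative values that are not taken from this diff:

    # Hypothetical values for illustration only
    outputs = {
        "text": '{"answer": 42}',
        "usage": {"total_tokens": 128},
        "finish_reason": "stop",
        "structured_output": {"answer": 42},  # present only when parsing succeeds
    }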
@@ -513,7 +534,12 @@
 
         if not model_schema:
             raise ModelNotExistError(f"Model {model_name} does not exist.")
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
+            completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
+        elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            # Set appropriate response format based on model capabilities
+            self._set_response_format(completion_params, model_schema.parameter_rules)
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
             model=model_name,
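The two branches above steer completion_params differently: native support serializes the prepared schema into a json_schema parameter, while the fallback only selects a response_format. A hedged sketch of the possible end states, assuming a model that exposes both parameters (key names follow the methods later in this diff; values are illustrative):

    # After _handle_native_json_schema (model supports structured output natively)
    completion_params = {
        "json_schema": '{"schema": {"type": "object"}, "name": "llm_response"}',
        "response_format": "json_schema",
    }

    # After _set_response_format (no native support; the prompt-based fallback is used)
    completion_params = {"response_format": "json_object"}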
@@ -724,10 +750,29 @@
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            filtered_prompt_messages = self._handle_prompt_based_schema(
+                prompt_messages=filtered_prompt_messages,
+            )
         stop = model_config.stop
         return filtered_prompt_messages, stop
 
+    def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
+        structured_output: dict[str, Any] | list[Any] = {}
+        try:
+            parsed = json.loads(result_text)
+            if not isinstance(parsed, dict | list):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        except json.JSONDecodeError:
+            # if result_text is not valid JSON, try to repair it
+            parsed = json_repair.loads(result_text)
+            if not isinstance(parsed, dict | list):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        return structured_output
+
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle
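_parse_structured_output trusts json.loads first and only falls back to json_repair when decoding fails. A minimal standalone sketch of that fallback, with an assumed malformed sample string:

    import json

    import json_repair

    def parse_with_repair(text: str) -> dict | list:
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # json_repair tolerates common LLM output glitches such as
            # single quotes, trailing commas, and unquoted keys
            parsed = json_repair.loads(text)
        if not isinstance(parsed, dict | list):
            raise ValueError(f"not a JSON object or array: {text!r}")
        return parsed

    print(parse_with_repair("{'answer': 42,}"))  # {'answer': 42}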
@@ -926,6 +971,166 @@
 
         return prompt_messages
+
+    def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
+        """
+        Handle structured output for models with native JSON schema support.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        :return: Updated model parameters with JSON schema configuration
+        """
+        # Process schema according to model requirements
+        schema = self._fetch_structured_output_schema()
+        schema_json = self._prepare_schema_for_model(schema)
+
+        # Set JSON schema in parameters
+        model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
+
+        # Set appropriate response format if required by the model
+        for rule in rules:
+            if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
+                model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
+
+        return model_parameters
+
+    def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
+        """
+        Handle structured output for models without native JSON schema support.
+        This function modifies the prompt messages to include schema-based output requirements.
+
+        Args:
+            prompt_messages: Original sequence of prompt messages
+
+        Returns:
+            list[PromptMessage]: Updated prompt messages with structured output requirements
+        """
+        # Convert schema to string format
+        schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
+
+        # Find the existing system prompt, if any
+        system_prompt = next(
+            (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
+            None,
+        )
+        structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
+        # Prepend the structured output instructions to any existing system prompt content
+        system_prompt_content = (
+            structured_output_prompt + "\n\n" + system_prompt.content
+            if system_prompt and isinstance(system_prompt.content, str)
+            else structured_output_prompt
+        )
+        system_prompt = SystemPromptMessage(content=system_prompt_content)
+
+        # Replace any existing system prompts with the combined one
+
+        filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
+        updated_prompt = [system_prompt] + filtered_prompts
+
+        return updated_prompt
+
+    def _set_response_format(self, model_parameters: dict, rules: list[ParameterRule]) -> None:
+        """
+        Set the appropriate response format parameter based on model rules.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        """
+        for rule in rules:
+            if rule.name == "response_format":
+                if ResponseFormat.JSON.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON.value
+                elif ResponseFormat.JSON_OBJECT.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
+
+    def _prepare_schema_for_model(self, schema: dict) -> dict:
+        """
+        Prepare JSON schema based on model requirements.
+
+        Different models have different requirements for JSON schema formatting.
+        This function handles these differences.
+
+        :param schema: The original JSON schema
+        :return: Processed schema compatible with the current model
+        """
+
+        # Deep copy via a JSON round-trip so the in-place helpers below cannot mutate the original schema
+        processed_schema = json.loads(json.dumps(schema))
+
+        # Convert boolean types to string types (common requirement)
+        convert_boolean_to_string(processed_schema)
+
+        # Apply model-specific transformations
+        if SpecialModelType.GEMINI in self.node_data.model.name:
+            remove_additional_properties(processed_schema)
+            return processed_schema
+        elif SpecialModelType.OLLAMA in self.node_data.model.provider:
+            return processed_schema
+        else:
+            # Default format with name field
+            return {"schema": processed_schema, "name": "llm_response"}
+
+    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
+        """
+        Fetch model schema
+        """
+        model_name = self.node_data.model.name
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
+        )
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_credentials = model_instance.credentials
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        return model_schema
+
+    def _fetch_structured_output_schema(self) -> dict[str, Any]:
+        """
+        Fetch the structured output schema from the node data.
+
+        Returns:
+            dict[str, Any]: The structured output schema
+        """
+        if not self.node_data.structured_output:
+            raise LLMNodeError("Please provide a valid structured output schema")
+        structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
+        if not structured_output_schema:
+            raise LLMNodeError("Please provide a valid structured output schema")
+
+        try:
+            schema = json.loads(structured_output_schema)
+            if not isinstance(schema, dict):
+                raise LLMNodeError("structured_output_schema must be a JSON object")
+            return schema
+        except json.JSONDecodeError:
+            raise LLMNodeError("structured_output_schema is not valid JSON")
+
+    def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
+        """
+        Check if the current model supports structured output.
+
+        Returns:
+            SupportStructuredOutputStatus: The support status of structured output
+        """
+        # Early return if structured output is disabled
+        if (
+            not isinstance(self.node_data, LLMNodeData)
+            or not self.node_data.structured_output_enabled
+            or not self.node_data.structured_output
+        ):
+            return SupportStructuredOutputStatus.DISABLED
+        # Get model schema and check if it exists
+        model_schema = self._fetch_model_schema(self.node_data.model.provider)
+        if not model_schema:
+            return SupportStructuredOutputStatus.DISABLED
+
+        # Check if model supports structured output feature
+        return (
+            SupportStructuredOutputStatus.SUPPORTED
+            if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
+            else SupportStructuredOutputStatus.UNSUPPORTED
+        )
 
 
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
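_handle_prompt_based_schema above never edits messages in place; it strips existing system prompts and prepends a single combined one. A minimal sketch of the same merge with a stand-in template, since STRUCTURED_OUTPUT_PROMPT's actual text is not shown in this diff:

    import json

    # Stand-in for STRUCTURED_OUTPUT_PROMPT, which is defined elsewhere
    TEMPLATE = "You must answer with JSON that matches this schema:\n{{schema}}"

    def build_system_content(schema: dict, existing_system: str | None) -> str:
        instruction = TEMPLATE.replace("{{schema}}", json.dumps(schema, ensure_ascii=False))
        return instruction + "\n\n" + existing_system if existing_system else instruction

    print(build_system_content({"type": "object"}, "You are a helpful assistant."))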
@@ -1064,3 +1269,49 @@ def _handle_completion_template(
         )
         prompt_messages.append(prompt_message)
     return prompt_messages
+
+
+def remove_additional_properties(schema: dict) -> None:
+    """
+    Remove additionalProperties fields from JSON schema.
+    Used for models like Gemini that don't support this property.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Remove additionalProperties at current level
+    schema.pop("additionalProperties", None)
+
+    # Process nested structures recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            remove_additional_properties(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    remove_additional_properties(item)
+
+
+def convert_boolean_to_string(schema: dict) -> None:
+    """
+    Convert boolean type specifications to string in JSON schema.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Check for boolean type at current level
+    if schema.get("type") == "boolean":
+        schema["type"] = "string"
+
+    # Process nested dictionaries and lists recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            convert_boolean_to_string(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    convert_boolean_to_string(item)
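A quick usage sketch of the two helpers above, on an assumed sample schema; both mutate their argument in place:

    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"active": {"type": "boolean"}},
    }
    convert_boolean_to_string(schema)
    remove_additional_properties(schema)
    # schema == {"type": "object", "properties": {"active": {"type": "string"}}}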