@ -7,12 +7,12 @@ from datetime import UTC, datetime
from typing import TYPE_CHECKING , Any , Optional , cast
import json_repair
from sqlalchemy import select , update
from sqlalchemy . orm import Session
from configs import dify_config
from core . app . entities . app_invoke_entities import ModelConfigWithCredentialsEntity
from core . entities . model_entities import ModelStatus
from core . entities . provider_entities import QuotaUnit
from core . errors . error import ModelCurrentlyNotSupportError , ProviderTokenNotInitError , QuotaExceededError
from core . file import FileType , file_manager
from core . helper . code_executor import CodeExecutor , CodeLanguage
from core . memory . token_buffer_memory import TokenBufferMemory
@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core . plugin . entities . plugin import ModelProviderID
from core . prompt . entities . advanced_prompt_entities import CompletionModelPromptTemplate , MemoryConfig
from core . prompt . utils . prompt_message_util import PromptMessageUtil
from core . rag . entities . citation_metadata import RetrievalSourceMetadata
from core . variables import (
ArrayAnySegment ,
ArrayFileSegment ,
@ -53,9 +54,10 @@ from core.variables import (
StringSegment ,
)
from core . workflow . constants import SYSTEM_VARIABLE_NODE_ID
from core . workflow . entities . node_entities import NodeRun MetadataKey, NodeRun Result
from core . workflow . entities . node_entities import NodeRun Result
from core . workflow . entities . variable_entities import VariableSelector
from core . workflow . entities . variable_pool import VariablePool
from core . workflow . entities . workflow_node_execution import WorkflowNodeExecutionMetadataKey , WorkflowNodeExecutionStatus
from core . workflow . enums import SystemVariableKey
from core . workflow . graph_engine . entities . event import InNodeEvent
from core . workflow . nodes . base import BaseNode
@ -70,14 +72,12 @@ from core.workflow.nodes.event import (
from core . workflow . utils . structured_output . entities import (
ResponseFormat ,
SpecialModelType ,
SupportStructuredOutputStatus ,
)
from core . workflow . utils . structured_output . prompt import STRUCTURED_OUTPUT_PROMPT
from core . workflow . utils . variable_template_parser import VariableTemplateParser
from extensions . ext_database import db
from models . model import Conversation
from models . provider import Provider , ProviderType
from models . workflow import WorkflowNodeExecutionStatus
from . entities import (
LLMNodeChatModelMessage ,
@ -267,14 +267,14 @@ class LLMNode(BaseNode[LLMNodeData]):
process_data = process_data ,
outputs = outputs ,
metadata = {
NodeRu nMetadataKey. TOTAL_TOKENS : usage . total_tokens ,
NodeRu nMetadataKey. TOTAL_PRICE : usage . total_price ,
NodeRu nMetadataKey. CURRENCY : usage . currency ,
WorkflowNodeExecutio nMetadataKey. TOTAL_TOKENS : usage . total_tokens ,
WorkflowNodeExecutio nMetadataKey. TOTAL_PRICE : usage . total_price ,
WorkflowNodeExecutio nMetadataKey. CURRENCY : usage . currency ,
} ,
llm_usage = usage ,
)
)
except LLMNod eError as e :
except Valu eError as e :
yield RunCompletedEvent (
run_result = NodeRunResult (
status = WorkflowNodeExecutionStatus . FAILED ,
@ -302,8 +302,6 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages : Sequence [ PromptMessage ] ,
stop : Optional [ Sequence [ str ] ] = None ,
) - > Generator [ NodeEvent , None , None ] :
db . session . close ( )
invoke_result = model_instance . invoke_llm (
prompt_messages = list ( prompt_messages ) ,
model_parameters = node_data_model . completion_params ,
@ -474,7 +472,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent ( retriever_resources = [ ] , context = context_value_variable . value )
elif isinstance ( context_value_variable , ArraySegment ) :
context_str = " "
original_retriever_resource = [ ]
original_retriever_resource : list [ RetrievalSourceMetadata ] = [ ]
for item in context_value_variable . value :
if isinstance ( item , str ) :
context_str + = item + " \n "
@ -492,7 +490,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources = original_retriever_resource , context = context_str . strip ( )
)
def _convert_to_original_retriever_resource ( self , context_dict : dict ) - > Optional [ dict ] :
def _convert_to_original_retriever_resource ( self , context_dict : dict ) :
if (
" metadata " in context_dict
and " _source " in context_dict [ " metadata " ]
@ -500,24 +498,24 @@ class LLMNode(BaseNode[LLMNodeData]):
) :
metadata = context_dict . get ( " metadata " , { } )
source = {
" position " : metadata . get ( " position " ) ,
" dataset_id " : metadata . get ( " dataset_id " ) ,
" dataset_name " : metadata . get ( " dataset_name " ) ,
" document_id " : metadata . get ( " document_id " ) ,
" document_name " : metadata . get ( " document_name " ) ,
" data_source_type " : metadata . get ( " data_source_type " ) ,
" segment_id " : metadata . get ( " segment_id " ) ,
" retriever_from " : metadata . get ( " retriever_from " ) ,
" score " : metadata . get ( " score " ) ,
" hit_count " : metadata . get ( " segment_hit_count " ) ,
" word_count " : metadata . get ( " segment_word_count " ) ,
" segment_position " : metadata . get ( " segment_position " ) ,
" index_node_hash " : metadata . get ( " segment_index_node_hash " ) ,
" content " : context_dict . get ( " content " ) ,
" page " : metadata . get ( " page " ) ,
" doc_metadata " : metadata . get ( " doc_metadata " ) ,
}
source = RetrievalSourceMetadata (
position = metadata . get ( " position " ) ,
dataset_id = metadata . get ( " dataset_id " ) ,
dataset_name = metadata . get ( " dataset_name " ) ,
document_id = metadata . get ( " document_id " ) ,
document_name = metadata . get ( " document_name " ) ,
data_source_type = metadata . get ( " data_source_type " ) ,
segment_id = metadata . get ( " segment_id " ) ,
retriever_from = metadata . get ( " retriever_from " ) ,
score = metadata . get ( " score " ) ,
hit_count = metadata . get ( " segment_hit_count " ) ,
word_count = metadata . get ( " segment_word_count " ) ,
segment_position = metadata . get ( " segment_position " ) ,
index_node_hash = metadata . get ( " segment_index_node_hash " ) ,
content = context_dict . get ( " content " ) ,
page = metadata . get ( " page " ) ,
doc_metadata = metadata . get ( " doc_metadata " ) ,
)
return source
@ -526,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config (
self , node_data_model : ModelConfig
) - > tuple [ ModelInstance , ModelConfigWithCredentialsEntity ] :
model_name = node_data_model . name
provider_name = node_data_model . provider
if not node_data_model . mode :
raise LLMModeRequiredError ( " LLM mode is required. " )
model_manager = ModelManager ( )
model_instance = model_manager . get_model_instance (
tenant_id = self . tenant_id , model_type = ModelType . LLM , provider = provider_name , model = model_name
model = ModelManager ( ) . get_model_instance (
tenant_id = self . tenant_id ,
model_type = ModelType . LLM ,
provider = node_data_model . provider ,
model = node_data_model . name ,
)
provider_model_bundle = model_instance . provider_model_bundle
model_type_instance = model_instance . model_type_instance
model_type_instance = cast ( LargeLanguageModel , model_type_instance )
model_credentials = model_instance . credentials
model . model_type_instance = cast ( LargeLanguageModel , model . model_type_instance )
# check model
provider_model = provider_model_bundle. configuration . get_provider_model (
model = model_ name, model_type = ModelType . LLM
provider_model = model . provider_model_bundle . configuration . get_provider_model (
model = node_data_model . name , model_type = ModelType . LLM
)
if provider_model is None :
raise ModelNotExistError ( f " Model { model_name } not exist. " )
if provider_model . status == ModelStatus . NO_CONFIGURE :
raise ProviderTokenNotInitError ( f " Model { model_name } credentials is not initialized. " )
elif provider_model . status == ModelStatus . NO_PERMISSION :
raise ModelCurrentlyNotSupportError ( f " Dify Hosted OpenAI { model_name } currently not support. " )
elif provider_model . status == ModelStatus . QUOTA_EXCEEDED :
raise QuotaExceededError ( f " Model provider { provider_name } quota exceeded. " )
raise ModelNotExistError ( f " Model { node_data_model . name } not exist. " )
provider_model . raise_for_status ( )
# model config
completion_params = node_data_model . completion_params
stop = [ ]
if " stop " in completion_params :
stop = completion_params [ " stop " ]
del completion_params [ " stop " ]
# get model mode
model_mode = node_data_model . mode
if not model_mode :
raise LLMModeRequiredError ( " LLM mode is required. " )
model_schema = model_type_instance . get_model_schema ( model_name , model_credentials )
stop : list [ str ] = [ ]
if " stop " in node_data_model . completion_params :
stop = node_data_model . completion_params . pop ( " stop " )
model_schema = model . model_type_instance . get_model_schema ( node_data_model . name , model . credentials )
if not model_schema :
raise ModelNotExistError ( f " Model { model_name } not exist. " )
support_structured_output = self . _check_model_structured_output_support ( )
if support_structured_output == SupportStructuredOutputStatus . SUPPORTED :
completion_params = self . _handle_native_json_schema ( completion_params , model_schema . parameter_rules )
elif support_structured_output == SupportStructuredOutputStatus . UNSUPPORTED :
# Set appropriate response format based on model capabilities
self . _set_response_format ( completion_params , model_schema . parameter_rules )
return model_instance , ModelConfigWithCredentialsEntity (
provider = provider_name ,
model = model_name ,
raise ModelNotExistError ( f " Model { node_data_model . name } not exist. " )
if self . node_data . structured_output_enabled :
if model_schema . support_structure_output :
node_data_model . completion_params = self . _handle_native_json_schema (
node_data_model . completion_params , model_schema . parameter_rules
)
else :
# Set appropriate response format based on model capabilities
self . _set_response_format ( node_data_model . completion_params , model_schema . parameter_rules )
return model , ModelConfigWithCredentialsEntity (
provider = node_data_model . provider ,
model = node_data_model . name ,
model_schema = model_schema ,
mode = model_mode ,
provider_model_bundle = provider_model_bundle ,
credentials = model_credentials ,
parameters = completion_params,
mode = node_data_model. mode,
provider_model_bundle = model. provider_model_bundle,
credentials = model . credentials,
parameters = node_data_model. completion_params,
stop = stop ,
)
@ -602,15 +588,11 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
conversation_id = conversation_id_variable . value
# get conversation
conversation = (
db . session . query ( Conversation )
. filter ( Conversation . app_id == self . app_id , Conversation . id == conversation_id )
. first ( )
)
if not conversation :
return None
with Session ( db . engine , expire_on_commit = False ) as session :
stmt = select ( Conversation ) . where ( Conversation . app_id == self . app_id , Conversation . id == conversation_id )
conversation = session . scalar ( stmt )
if not conversation :
return None
memory = TokenBufferMemory ( conversation = conversation , model_instance = model_instance )
@ -789,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
" No prompt found in the LLM configuration. "
" Please ensure a prompt is properly configured before proceeding. "
)
support_structured_output = self . _check_model_structured_output_support ( )
if support_structured_output == SupportStructuredOutputStatus . UNSUPPORTED :
filtered_prompt_messages = self . _handle_prompt_based_schema (
prompt_messages = filtered_prompt_messages ,
)
stop = model_config . stop
return filtered_prompt_messages , stop
model = ModelManager ( ) . get_model_instance (
tenant_id = self . tenant_id ,
model_type = ModelType . LLM ,
provider = self . node_data . model . provider ,
model = self . node_data . model . name ,
)
model_schema = model . model_type_instance . get_model_schema (
model = self . node_data . model . name ,
credentials = model . credentials ,
)
if not model_schema :
raise ModelNotExistError ( f " Model { self . node_data . model . name } not exist. " )
if self . node_data . structured_output_enabled :
if not model_schema . support_structure_output :
filtered_prompt_messages = self . _handle_prompt_based_schema (
prompt_messages = filtered_prompt_messages ,
)
return filtered_prompt_messages , model_config . stop
def _parse_structured_output ( self , result_text : str ) - > dict [ str , Any ] :
structured_output : dict [ str , Any ] = { }
@ -846,20 +840,24 @@ class LLMNode(BaseNode[LLMNodeData]):
used_quota = 1
if used_quota is not None and system_configuration . current_quota_type is not None :
db . session . query ( Provider ) . filter (
Provider . tenant_id == tenant_id ,
# TODO: Use provider name with prefix after the data migration.
Provider . provider_name == ModelProviderID ( model_instance . provider ) . provider_name ,
Provider . provider_type == ProviderType . SYSTEM . value ,
Provider . quota_type == system_configuration . current_quota_type . value ,
Provider . quota_limit > Provider . quota_used ,
) . update (
{
" quota_used " : Provider . quota_used + used_quota ,
" last_used " : datetime . now ( tz = UTC ) . replace ( tzinfo = None ) ,
}
)
db . session . commit ( )
with Session ( db . engine ) as session :
stmt = (
update ( Provider )
. where (
Provider . tenant_id == tenant_id ,
# TODO: Use provider name with prefix after the data migration.
Provider . provider_name == ModelProviderID ( model_instance . provider ) . provider_name ,
Provider . provider_type == ProviderType . SYSTEM . value ,
Provider . quota_type == system_configuration . current_quota_type . value ,
Provider . quota_limit > Provider . quota_used ,
)
. values (
quota_used = Provider . quota_used + used_quota ,
last_used = datetime . now ( tz = UTC ) . replace ( tzinfo = None ) ,
)
)
session . execute ( stmt )
session . commit ( )
@classmethod
def _extract_variable_selector_to_variable_mapping (
@ -902,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_mapping [ " #context# " ] = node_data . context . variable_selector
if node_data . vision . enabled :
variable_mapping [ " #files# " ] = [ " sys " , SystemVariableKey . FILES . value ]
variable_mapping [ " #files# " ] = node_data . vision . configs . variable_selector
if node_data . memory :
variable_mapping [ " #sys.query# " ] = [ " sys " , SystemVariableKey . QUERY . value ]
@ -1184,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json . JSONDecodeError :
raise LLMNodeError ( " structured_output_schema is not valid JSON format " )
def _check_model_structured_output_support ( self ) - > SupportStructuredOutputStatus :
"""
Check if the current model supports structured output .
Returns :
SupportStructuredOutput : The support status of structured output
"""
# Early return if structured output is disabled
if (
not isinstance ( self . node_data , LLMNodeData )
or not self . node_data . structured_output_enabled
or not self . node_data . structured_output
) :
return SupportStructuredOutputStatus . DISABLED
# Get model schema and check if it exists
model_schema = self . _fetch_model_schema ( self . node_data . model . provider )
if not model_schema :
return SupportStructuredOutputStatus . DISABLED
# Check if model supports structured output feature
return (
SupportStructuredOutputStatus . SUPPORTED
if bool ( model_schema . features and ModelFeature . STRUCTURED_OUTPUT in model_schema . features )
else SupportStructuredOutputStatus . UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown (
self ,
contents : str | list [ PromptMessageContentUnionTypes ] | None ,