@ -3,16 +3,11 @@ import io
import json
import json
import logging
import logging
from collections . abc import Generator , Mapping , Sequence
from collections . abc import Generator , Mapping , Sequence
from datetime import UTC , datetime
from typing import TYPE_CHECKING , Any , Optional , cast
from typing import TYPE_CHECKING , Any , Optional , cast
import json_repair
import json_repair
from configs import dify_config
from core . app . entities . app_invoke_entities import ModelConfigWithCredentialsEntity
from core . app . entities . app_invoke_entities import ModelConfigWithCredentialsEntity
from core . entities . model_entities import ModelStatus
from core . entities . provider_entities import QuotaUnit
from core . errors . error import ModelCurrentlyNotSupportError , ProviderTokenNotInitError , QuotaExceededError
from core . file import FileType , file_manager
from core . file import FileType , file_manager
from core . helper . code_executor import CodeExecutor , CodeLanguage
from core . helper . code_executor import CodeExecutor , CodeLanguage
from core . memory . token_buffer_memory import TokenBufferMemory
from core . memory . token_buffer_memory import TokenBufferMemory
@ -40,11 +35,10 @@ from core.model_runtime.entities.model_entities import (
)
)
from core . model_runtime . model_providers . __base . large_language_model import LargeLanguageModel
from core . model_runtime . model_providers . __base . large_language_model import LargeLanguageModel
from core . model_runtime . utils . encoders import jsonable_encoder
from core . model_runtime . utils . encoders import jsonable_encoder
from core . plugin . entities . plugin import ModelProviderID
from core . prompt . entities . advanced_prompt_entities import CompletionModelPromptTemplate , MemoryConfig
from core . prompt . entities . advanced_prompt_entities import CompletionModelPromptTemplate , MemoryConfig
from core . prompt . utils . prompt_message_util import PromptMessageUtil
from core . prompt . utils . prompt_message_util import PromptMessageUtil
from core . rag . entities . citation_metadata import RetrievalSourceMetadata
from core . variables import (
from core . variables import (
ArrayAnySegment ,
ArrayFileSegment ,
ArrayFileSegment ,
ArraySegment ,
ArraySegment ,
FileSegment ,
FileSegment ,
@ -71,14 +65,11 @@ from core.workflow.nodes.event import (
from core . workflow . utils . structured_output . entities import (
from core . workflow . utils . structured_output . entities import (
ResponseFormat ,
ResponseFormat ,
SpecialModelType ,
SpecialModelType ,
SupportStructuredOutputStatus ,
)
)
from core . workflow . utils . structured_output . prompt import STRUCTURED_OUTPUT_PROMPT
from core . workflow . utils . structured_output . prompt import STRUCTURED_OUTPUT_PROMPT
from core . workflow . utils . variable_template_parser import VariableTemplateParser
from core . workflow . utils . variable_template_parser import VariableTemplateParser
from extensions . ext_database import db
from models . model import Conversation
from models . provider import Provider , ProviderType
from . import llm_utils
from . entities import (
from . entities import (
LLMNodeChatModelMessage ,
LLMNodeChatModelMessage ,
LLMNodeCompletionModelPromptTemplate ,
LLMNodeCompletionModelPromptTemplate ,
@ -88,7 +79,6 @@ from .entities import (
from . exc import (
from . exc import (
InvalidContextStructureError ,
InvalidContextStructureError ,
InvalidVariableTypeError ,
InvalidVariableTypeError ,
LLMModeRequiredError ,
LLMNodeError ,
LLMNodeError ,
MemoryRolePrefixRequiredError ,
MemoryRolePrefixRequiredError ,
ModelNotExistError ,
ModelNotExistError ,
@ -160,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]):
result_text = " "
result_text = " "
usage = LLMUsage . empty_usage ( )
usage = LLMUsage . empty_usage ( )
finish_reason = None
finish_reason = None
variable_pool = self . graph_runtime_state . variable_pool
try :
try :
# init messages template
# init messages template
@ -178,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch files
# fetch files
files = (
files = (
self . _fetch_files ( selector = self . node_data . vision . configs . variable_selector )
llm_utils . fetch_files (
variable_pool = variable_pool ,
selector = self . node_data . vision . configs . variable_selector ,
)
if self . node_data . vision . enabled
if self . node_data . vision . enabled
else [ ]
else [ ]
)
)
@ -200,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]):
model_instance , model_config = self . _fetch_model_config ( self . node_data . model )
model_instance , model_config = self . _fetch_model_config ( self . node_data . model )
# fetch memory
# fetch memory
memory = self . _fetch_memory ( node_data_memory = self . node_data . memory , model_instance = model_instance )
memory = llm_utils . fetch_memory (
variable_pool = variable_pool ,
app_id = self . app_id ,
node_data_memory = self . node_data . memory ,
model_instance = model_instance ,
)
query = None
query = None
if self . node_data . memory :
if self . node_data . memory :
query = self . node_data . memory . query_prompt_template
query = self . node_data . memory . query_prompt_template
if not query and (
if not query and (
query_variable := self . graph_runtime_state . variable_pool . get (
query_variable := variable_pool . get ( ( SYSTEM_VARIABLE_NODE_ID , SystemVariableKey . QUERY ) )
( SYSTEM_VARIABLE_NODE_ID , SystemVariableKey . QUERY )
)
) :
) :
query = query_variable . text
query = query_variable . text
@ -222,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config = self . node_data . memory ,
memory_config = self . node_data . memory ,
vision_enabled = self . node_data . vision . enabled ,
vision_enabled = self . node_data . vision . enabled ,
vision_detail = self . node_data . vision . configs . detail ,
vision_detail = self . node_data . vision . configs . detail ,
variable_pool = self . graph_runtime_state . variable_pool ,
variable_pool = variable_pool ,
jinja2_variables = self . node_data . prompt_config . jinja2_variables ,
jinja2_variables = self . node_data . prompt_config . jinja2_variables ,
)
)
@ -251,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = event . usage
usage = event . usage
finish_reason = event . finish_reason
finish_reason = event . finish_reason
# deduct quota
# deduct quota
self . deduct_llm_quota ( tenant_id = self . tenant_id , model_instance = model_instance , usage = usage )
llm_utils . deduct_llm_quota ( tenant_id = self . tenant_id , model_instance = model_instance , usage = usage )
break
break
outputs = { " text " : result_text , " usage " : jsonable_encoder ( usage ) , " finish_reason " : finish_reason }
outputs = { " text " : result_text , " usage " : jsonable_encoder ( usage ) , " finish_reason " : finish_reason }
structured_output = process_structured_output ( result_text )
structured_output = process_structured_output ( result_text )
@ -274,7 +271,7 @@ class LLMNode(BaseNode[LLMNodeData]):
llm_usage = usage ,
llm_usage = usage ,
)
)
)
)
except LLMNod eError as e :
except Valu eError as e :
yield RunCompletedEvent (
yield RunCompletedEvent (
run_result = NodeRunResult (
run_result = NodeRunResult (
status = WorkflowNodeExecutionStatus . FAILED ,
status = WorkflowNodeExecutionStatus . FAILED ,
@ -302,8 +299,6 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages : Sequence [ PromptMessage ] ,
prompt_messages : Sequence [ PromptMessage ] ,
stop : Optional [ Sequence [ str ] ] = None ,
stop : Optional [ Sequence [ str ] ] = None ,
) - > Generator [ NodeEvent , None , None ] :
) - > Generator [ NodeEvent , None , None ] :
db . session . close ( )
invoke_result = model_instance . invoke_llm (
invoke_result = model_instance . invoke_llm (
prompt_messages = list ( prompt_messages ) ,
prompt_messages = list ( prompt_messages ) ,
model_parameters = node_data_model . completion_params ,
model_parameters = node_data_model . completion_params ,
@ -449,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]):
return inputs
return inputs
def _fetch_files ( self , * , selector : Sequence [ str ] ) - > Sequence [ " File " ] :
variable = self . graph_runtime_state . variable_pool . get ( selector )
if variable is None :
return [ ]
elif isinstance ( variable , FileSegment ) :
return [ variable . value ]
elif isinstance ( variable , ArrayFileSegment ) :
return variable . value
elif isinstance ( variable , NoneSegment | ArrayAnySegment ) :
return [ ]
raise InvalidVariableTypeError ( f " Invalid variable type: { type ( variable ) } " )
def _fetch_context ( self , node_data : LLMNodeData ) :
def _fetch_context ( self , node_data : LLMNodeData ) :
if not node_data . context . enabled :
if not node_data . context . enabled :
return
return
@ -474,7 +457,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent ( retriever_resources = [ ] , context = context_value_variable . value )
yield RunRetrieverResourceEvent ( retriever_resources = [ ] , context = context_value_variable . value )
elif isinstance ( context_value_variable , ArraySegment ) :
elif isinstance ( context_value_variable , ArraySegment ) :
context_str = " "
context_str = " "
original_retriever_resource = [ ]
original_retriever_resource : list [ RetrievalSourceMetadata ] = [ ]
for item in context_value_variable . value :
for item in context_value_variable . value :
if isinstance ( item , str ) :
if isinstance ( item , str ) :
context_str + = item + " \n "
context_str + = item + " \n "
@ -492,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources = original_retriever_resource , context = context_str . strip ( )
retriever_resources = original_retriever_resource , context = context_str . strip ( )
)
)
def _convert_to_original_retriever_resource ( self , context_dict : dict ) - > Optional [ dict ] :
def _convert_to_original_retriever_resource ( self , context_dict : dict ) :
if (
if (
" metadata " in context_dict
" metadata " in context_dict
and " _source " in context_dict [ " metadata " ]
and " _source " in context_dict [ " metadata " ]
@ -500,24 +483,24 @@ class LLMNode(BaseNode[LLMNodeData]):
) :
) :
metadata = context_dict . get ( " metadata " , { } )
metadata = context_dict . get ( " metadata " , { } )
source = {
source = RetrievalSourceMetadata (
" position " : metadata . get ( " position " ) ,
position = metadata . get ( " position " ) ,
" dataset_id " : metadata . get ( " dataset_id " ) ,
dataset_id = metadata . get ( " dataset_id " ) ,
" dataset_name " : metadata . get ( " dataset_name " ) ,
dataset_name = metadata . get ( " dataset_name " ) ,
" document_id " : metadata . get ( " document_id " ) ,
document_id = metadata . get ( " document_id " ) ,
" document_name " : metadata . get ( " document_name " ) ,
document_name = metadata . get ( " document_name " ) ,
" data_source_type " : metadata . get ( " data_source_type " ) ,
data_source_type = metadata . get ( " data_source_type " ) ,
" segment_id " : metadata . get ( " segment_id " ) ,
segment_id = metadata . get ( " segment_id " ) ,
" retriever_from " : metadata . get ( " retriever_from " ) ,
retriever_from = metadata . get ( " retriever_from " ) ,
" score " : metadata . get ( " score " ) ,
score = metadata . get ( " score " ) ,
" hit_count " : metadata . get ( " segment_hit_count " ) ,
hit_count = metadata . get ( " segment_hit_count " ) ,
" word_count " : metadata . get ( " segment_word_count " ) ,
word_count = metadata . get ( " segment_word_count " ) ,
" segment_position " : metadata . get ( " segment_position " ) ,
segment_position = metadata . get ( " segment_position " ) ,
" index_node_hash " : metadata . get ( " segment_index_node_hash " ) ,
index_node_hash = metadata . get ( " segment_index_node_hash " ) ,
" content " : context_dict . get ( " content " ) ,
content = context_dict . get ( " content " ) ,
" page " : metadata . get ( " page " ) ,
page = metadata . get ( " page " ) ,
" doc_metadata " : metadata . get ( " doc_metadata " ) ,
doc_metadata = metadata . get ( " doc_metadata " ) ,
}
)
return source
return source
@ -526,95 +509,25 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config (
def _fetch_model_config (
self , node_data_model : ModelConfig
self , node_data_model : ModelConfig
) - > tuple [ ModelInstance , ModelConfigWithCredentialsEntity ] :
) - > tuple [ ModelInstance , ModelConfigWithCredentialsEntity ] :
model_name = node_data_model . name
model , model_config_with_cred = llm_utils . fetch_model_config (
provider_name = node_data_model . provider
tenant_id = self . tenant_id , node_data_model = node_data_model
model_manager = ModelManager ( )
model_instance = model_manager . get_model_instance (
tenant_id = self . tenant_id , model_type = ModelType . LLM , provider = provider_name , model = model_name
)
provider_model_bundle = model_instance . provider_model_bundle
model_type_instance = model_instance . model_type_instance
model_type_instance = cast ( LargeLanguageModel , model_type_instance )
model_credentials = model_instance . credentials
# check model
provider_model = provider_model_bundle . configuration . get_provider_model (
model = model_name , model_type = ModelType . LLM
)
)
completion_params = model_config_with_cred . parameters
if provider_model is None :
model_schema = model . model_type_instance . get_model_schema ( node_data_model . name , model . credentials )
raise ModelNotExistError ( f " Model { model_name } not exist. " )
if provider_model . status == ModelStatus . NO_CONFIGURE :
raise ProviderTokenNotInitError ( f " Model { model_name } credentials is not initialized. " )
elif provider_model . status == ModelStatus . NO_PERMISSION :
raise ModelCurrentlyNotSupportError ( f " Dify Hosted OpenAI { model_name } currently not support. " )
elif provider_model . status == ModelStatus . QUOTA_EXCEEDED :
raise QuotaExceededError ( f " Model provider { provider_name } quota exceeded. " )
# model config
completion_params = node_data_model . completion_params
stop = [ ]
if " stop " in completion_params :
stop = completion_params [ " stop " ]
del completion_params [ " stop " ]
# get model mode
model_mode = node_data_model . mode
if not model_mode :
raise LLMModeRequiredError ( " LLM mode is required. " )
model_schema = model_type_instance . get_model_schema ( model_name , model_credentials )
if not model_schema :
if not model_schema :
raise ModelNotExistError ( f " Model { model_name } not exist. " )
raise ModelNotExistError ( f " Model { node_data_model . name } not exist. " )
support_structured_output = self . _check_model_structured_output_support ( )
if support_structured_output == SupportStructuredOutputStatus . SUPPORTED :
if self . node_data . structured_output_enabled :
if model_schema . support_structure_output :
completion_params = self . _handle_native_json_schema ( completion_params , model_schema . parameter_rules )
completion_params = self . _handle_native_json_schema ( completion_params , model_schema . parameter_rules )
el if support_structur ed_output == SupportStructuredOutputStatus . UNSUPPORTED :
else :
# Set appropriate response format based on model capabilities
# Set appropriate response format based on model capabilities
self . _set_response_format ( completion_params , model_schema . parameter_rules )
self . _set_response_format ( completion_params , model_schema . parameter_rules )
return model_instance , ModelConfigWithCredentialsEntity (
model_config_with_cred . parameters = completion_params
provider = provider_name ,
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
model = model_name ,
node_data_model . completion_params = completion_params
model_schema = model_schema ,
return model , model_config_with_cred
mode = model_mode ,
provider_model_bundle = provider_model_bundle ,
credentials = model_credentials ,
parameters = completion_params ,
stop = stop ,
)
def _fetch_memory(
    self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
    """Build the token-buffer memory for the conversation this node runs in.

    Args:
        node_data_memory: Memory configuration of the node; a falsy value
            means no memory is built.
        model_instance: Model instance handed to the ``TokenBufferMemory``.

    Returns:
        A ``TokenBufferMemory`` bound to the conversation, or ``None`` when
        memory is not configured, the conversation id system variable is not
        a ``StringSegment``, or no conversation row matches.
    """
    if not node_data_memory:
        return None

    # The conversation id is read from the system variables in the pool.
    conversation_id_segment = self.graph_runtime_state.variable_pool.get(
        ["sys", SystemVariableKey.CONVERSATION_ID.value]
    )
    if not isinstance(conversation_id_segment, StringSegment):
        return None

    # The lookup is restricted to this node's app as well as the conversation id.
    conversation_query = db.session.query(Conversation).filter(
        Conversation.app_id == self.app_id,
        Conversation.id == conversation_id_segment.value,
    )
    conversation = conversation_query.first()
    if not conversation:
        return None

    return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def _fetch_prompt_messages (
def _fetch_prompt_messages (
self ,
self ,
@ -789,13 +702,25 @@ class LLMNode(BaseNode[LLMNodeData]):
" No prompt found in the LLM configuration. "
" No prompt found in the LLM configuration. "
" Please ensure a prompt is properly configured before proceeding. "
" Please ensure a prompt is properly configured before proceeding. "
)
)
support_structured_output = self . _check_model_structured_output_support ( )
if support_structured_output == SupportStructuredOutputStatus . UNSUPPORTED :
model = ModelManager ( ) . get_model_instance (
tenant_id = self . tenant_id ,
model_type = ModelType . LLM ,
provider = model_config . provider ,
model = model_config . model ,
)
model_schema = model . model_type_instance . get_model_schema (
model = model_config . model ,
credentials = model . credentials ,
)
if not model_schema :
raise ModelNotExistError ( f " Model { model_config . model } not exist. " )
if self . node_data . structured_output_enabled :
if not model_schema . support_structure_output :
filtered_prompt_messages = self . _handle_prompt_based_schema (
filtered_prompt_messages = self . _handle_prompt_based_schema (
prompt_messages = filtered_prompt_messages ,
prompt_messages = filtered_prompt_messages ,
)
)
stop = model_config . stop
return filtered_prompt_messages , model_config . stop
return filtered_prompt_messages , stop
def _parse_structured_output ( self , result_text : str ) - > dict [ str , Any ] :
def _parse_structured_output ( self , result_text : str ) - > dict [ str , Any ] :
structured_output : dict [ str , Any ] = { }
structured_output : dict [ str , Any ] = { }
@ -816,51 +741,6 @@ class LLMNode(BaseNode[LLMNodeData]):
structured_output = parsed
structured_output = parsed
return structured_output
return structured_output
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """Record quota consumption for a system-provided model invocation.

    Does nothing unless the model's provider configuration uses the SYSTEM
    provider type. The amount charged depends on the unit of the quota
    configuration matching the current quota type: total tokens for TOKENS,
    the configured model credits for CREDITS, and 1 for any other unit.

    Args:
        tenant_id: Tenant whose provider row is updated.
        model_instance: Model instance that was invoked.
        usage: Usage reported for the invocation.
    """
    configuration = model_instance.provider_model_bundle.configuration
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = configuration.system_configuration

    # Only the first quota configuration matching the current quota type counts.
    active_quota = next(
        (
            candidate
            for candidate in system_configuration.quota_configurations
            if candidate.quota_type == system_configuration.current_quota_type
        ),
        None,
    )

    quota_unit = None
    if active_quota is not None:
        quota_unit = active_quota.quota_unit
        # NOTE(review): a limit of -1 skips deduction entirely — presumably "unlimited"; confirm.
        if active_quota.quota_limit == -1:
            return

    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = usage.total_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_instance.model)
        else:
            used_quota = 1

    if used_quota is None or system_configuration.current_quota_type is None:
        return

    # The increment is expressed in SQL, and the filter only matches rows
    # whose used quota is still below the limit.
    provider_rows = db.session.query(Provider).filter(
        Provider.tenant_id == tenant_id,
        # TODO: Use provider name with prefix after the data migration.
        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
        Provider.provider_type == ProviderType.SYSTEM.value,
        Provider.quota_type == system_configuration.current_quota_type.value,
        Provider.quota_limit > Provider.quota_used,
    )
    provider_rows.update(
        {
            "quota_used": Provider.quota_used + used_quota,
            "last_used": datetime.now(tz=UTC).replace(tzinfo=None),
        }
    )
    db.session.commit()
@classmethod
@classmethod
def _extract_variable_selector_to_variable_mapping (
def _extract_variable_selector_to_variable_mapping (
cls ,
cls ,
@ -902,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_mapping [ " #context# " ] = node_data . context . variable_selector
variable_mapping [ " #context# " ] = node_data . context . variable_selector
if node_data . vision . enabled :
if node_data . vision . enabled :
variable_mapping [ " #files# " ] = [ " sys " , SystemVariableKey . FILES . value ]
variable_mapping [ " #files# " ] = node_data . vision . configs . variable_selector
if node_data . memory :
if node_data . memory :
variable_mapping [ " #sys.query# " ] = [ " sys " , SystemVariableKey . QUERY . value ]
variable_mapping [ " #sys.query# " ] = [ " sys " , SystemVariableKey . QUERY . value ]
@ -1184,32 +1064,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json . JSONDecodeError :
except json . JSONDecodeError :
raise LLMNodeError ( " structured_output_schema is not valid JSON format " )
raise LLMNodeError ( " structured_output_schema is not valid JSON format " )
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
    """
    Report whether structured output can be used with the node's model.

    Returns:
        SupportStructuredOutputStatus: ``DISABLED`` when the node data is not
        an ``LLMNodeData``, structured output is not enabled or not
        configured, or no model schema is found; ``SUPPORTED`` when the schema
        lists ``ModelFeature.STRUCTURED_OUTPUT``; ``UNSUPPORTED`` otherwise.
    """
    # Structured output must be enabled and configured on an LLM node.
    structured_output_requested = (
        isinstance(self.node_data, LLMNodeData)
        and self.node_data.structured_output_enabled
        and self.node_data.structured_output
    )
    if not structured_output_requested:
        return SupportStructuredOutputStatus.DISABLED

    # NOTE(review): the schema lookup is given the provider, not the model name — confirm
    # against _fetch_model_schema.
    model_schema = self._fetch_model_schema(self.node_data.model.provider)
    if not model_schema:
        return SupportStructuredOutputStatus.DISABLED

    # An empty or missing feature list counts as unsupported.
    features = model_schema.features
    if features and ModelFeature.STRUCTURED_OUTPUT in features:
        return SupportStructuredOutputStatus.SUPPORTED
    return SupportStructuredOutputStatus.UNSUPPORTED
def _save_multimodal_output_and_convert_result_to_markdown (
def _save_multimodal_output_and_convert_result_to_markdown (
self ,
self ,
contents : str | list [ PromptMessageContentUnionTypes ] | None ,
contents : str | list [ PromptMessageContentUnionTypes ] | None ,