@ -1,4 +1,5 @@
import json
import logging
from collections . abc import Generator , Mapping , Sequence
from typing import TYPE_CHECKING , Any , Optional , cast
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core . entities . model_entities import ModelStatus
from core . entities . provider_entities import QuotaUnit
from core . errors . error import ModelCurrentlyNotSupportError , ProviderTokenNotInitError , QuotaExceededError
from core . file import FileType , file_manager
from core . helper . code_executor import CodeExecutor , CodeLanguage
from core . memory . token_buffer_memory import TokenBufferMemory
from core . model_manager import ModelInstance , ModelManager
from core . model_runtime . entities import (
AudioPromptMessageContent ,
ImagePromptMessageContent ,
PromptMessage ,
PromptMessageContentType ,
TextPromptMessageContent ,
VideoPromptMessageContent ,
)
from core . model_runtime . entities . llm_entities import LLMResult , LLMUsage
from core . model_runtime . entities . model_entities import ModelType
from core . model_runtime . entities . message_entities import (
AssistantPromptMessage ,
PromptMessageRole ,
SystemPromptMessage ,
UserPromptMessage ,
)
from core . model_runtime . entities . model_entities import ModelFeature , ModelPropertyKey , ModelType
from core . model_runtime . model_providers . __base . large_language_model import LargeLanguageModel
from core . model_runtime . utils . encoders import jsonable_encoder
from core . prompt . advanced_prompt_transform import AdvancedPromptTransform
from core . prompt . entities . advanced_prompt_entities import CompletionModelPromptTemplate , MemoryConfig
from core . prompt . utils . prompt_message_util import PromptMessageUtil
from core . variables import (
@ -34,6 +40,8 @@ from core.variables import (
)
from core . workflow . constants import SYSTEM_VARIABLE_NODE_ID
from core . workflow . entities . node_entities import NodeRunMetadataKey , NodeRunResult
from core . workflow . entities . variable_entities import VariableSelector
from core . workflow . entities . variable_pool import VariablePool
from core . workflow . enums import SystemVariableKey
from core . workflow . graph_engine . entities . event import InNodeEvent
from core . workflow . nodes . base import BaseNode
@ -58,18 +66,23 @@ from .entities import (
ModelConfig ,
)
from . exc import (
FileTypeNotSupportError ,
InvalidContextStructureError ,
InvalidVariableTypeError ,
LLMModeRequiredError ,
LLMNodeError ,
MemoryRolePrefixRequiredError ,
ModelNotExistError ,
NoPromptFoundError ,
TemplateTypeNotSupportError ,
VariableNotFoundError ,
)
if TYPE_CHECKING :
from core . file . models import File
logger = logging . getLogger ( __name__ )
class LLMNode ( BaseNode [ LLMNodeData ] ) :
_node_data_cls = LLMNodeData
@ -121,19 +134,19 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch memory
memory = self . _fetch_memory ( node_data_memory = self . node_data . memory , model_instance = model_instance )
# fetch prompt messages
if self . node_data . memory :
query = self . graph_runtime_state . variable_pool . get ( ( SYSTEM_VARIABLE_NODE_ID , SystemVariableKey . QUERY ) )
if not query :
raise VariableNotFoundError ( " Query not found " )
query = query . text
else :
query = None
if self . node_data . memory :
query = self . node_data . memory . query_prompt_template
if query is None and (
query_variable := self . graph_runtime_state . variable_pool . get (
( SYSTEM_VARIABLE_NODE_ID , SystemVariableKey . QUERY )
)
) :
query = query_variable . text
prompt_messages , stop = self . _fetch_prompt_messages (
system_query = query ,
inputs = inputs ,
files = files ,
user_query = query ,
user_files = files ,
context = context ,
memory = memory ,
model_config = model_config ,
@ -141,6 +154,8 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config = self . node_data . memory ,
vision_enabled = self . node_data . vision . enabled ,
vision_detail = self . node_data . vision . configs . detail ,
variable_pool = self . graph_runtime_state . variable_pool ,
jinja2_variables = self . node_data . prompt_config . jinja2_variables ,
)
process_data = {
@ -181,6 +196,17 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
return
except Exception as e :
logger . exception ( f " Node { self . node_id } failed to run " )
yield RunCompletedEvent (
run_result = NodeRunResult (
status = WorkflowNodeExecutionStatus . FAILED ,
error = str ( e ) ,
inputs = node_inputs ,
process_data = process_data ,
)
)
return
outputs = { " text " : result_text , " usage " : jsonable_encoder ( usage ) , " finish_reason " : finish_reason }
@ -203,8 +229,8 @@ class LLMNode(BaseNode[LLMNodeData]):
self ,
node_data_model : ModelConfig ,
model_instance : ModelInstance ,
prompt_messages : list [ PromptMessage ] ,
stop : Optional [ list [ str ] ] = None ,
prompt_messages : Sequence [ PromptMessage ] ,
stop : Optional [ Sequence [ str ] ] = None ,
) - > Generator [ NodeEvent , None , None ] :
db . session . close ( )
@ -519,9 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages (
self ,
* ,
system_query : str | None = None ,
inputs : dict [ str , str ] | None = None ,
files : Sequence [ " File " ] ,
user_query : str | None = None ,
user_files : Sequence [ " File " ] ,
context : str | None = None ,
memory : TokenBufferMemory | None = None ,
model_config : ModelConfigWithCredentialsEntity ,
@ -529,58 +554,144 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config : MemoryConfig | None = None ,
vision_enabled : bool = False ,
vision_detail : ImagePromptMessageContent . DETAIL ,
) - > tuple [ list [ PromptMessage ] , Optional [ list [ str ] ] ] :
inputs = inputs or { }
prompt_transform = AdvancedPromptTransform ( with_variable_tmpl = True )
prompt_messages = prompt_transform . get_prompt (
prompt_template = prompt_template ,
inputs = inputs ,
query = system_query or " " ,
files = files ,
variable_pool : VariablePool ,
jinja2_variables : Sequence [ VariableSelector ] ,
) - > tuple [ Sequence [ PromptMessage ] , Optional [ Sequence [ str ] ] ] :
prompt_messages = [ ]
if isinstance ( prompt_template , list ) :
# For chat model
prompt_messages . extend (
_handle_list_messages (
messages = prompt_template ,
context = context ,
jinja2_variables = jinja2_variables ,
variable_pool = variable_pool ,
vision_detail_config = vision_detail ,
)
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode (
memory = memory ,
memory_config = memory_config ,
model_config = model_config ,
)
# Extend prompt_messages with memory messages
prompt_messages . extend ( memory_messages )
# Add current query to the prompt messages
if user_query :
message = LLMNodeChatModelMessage (
text = user_query ,
role = PromptMessageRole . USER ,
edition_type = " basic " ,
)
prompt_messages . extend (
_handle_list_messages (
messages = [ message ] ,
context = " " ,
jinja2_variables = [ ] ,
variable_pool = variable_pool ,
vision_detail_config = vision_detail ,
)
)
elif isinstance ( prompt_template , LLMNodeCompletionModelPromptTemplate ) :
# For completion model
prompt_messages . extend (
_handle_completion_template (
template = prompt_template ,
context = context ,
jinja2_variables = jinja2_variables ,
variable_pool = variable_pool ,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode (
memory = memory ,
memory_config = memory_config ,
model_config = model_config ,
)
stop = model_config . stop
# Insert histories into the prompt
prompt_content = prompt_messages [ 0 ] . content
if " #histories# " in prompt_content :
prompt_content = prompt_content . replace ( " #histories# " , memory_text )
else :
prompt_content = memory_text + " \n " + prompt_content
prompt_messages [ 0 ] . content = prompt_content
# Add current query to the prompt message
if user_query :
prompt_content = prompt_messages [ 0 ] . content . replace ( " #sys.query# " , user_query )
prompt_messages [ 0 ] . content = prompt_content
else :
raise TemplateTypeNotSupportError ( type_name = str ( type ( prompt_template ) ) )
if vision_enabled and user_files :
file_prompts = [ ]
for file in user_files :
file_prompt = file_manager . to_prompt_message_content ( file , image_detail_config = vision_detail )
file_prompts . append ( file_prompt )
if (
len ( prompt_messages ) > 0
and isinstance ( prompt_messages [ - 1 ] , UserPromptMessage )
and isinstance ( prompt_messages [ - 1 ] . content , list )
) :
prompt_messages [ - 1 ] = UserPromptMessage ( content = prompt_messages [ - 1 ] . content + file_prompts )
else :
prompt_messages . append ( UserPromptMessage ( content = file_prompts ) )
# Filter prompt messages
filtered_prompt_messages = [ ]
for prompt_message in prompt_messages :
if prompt_message . is_empty ( ) :
continue
if not isinstance ( prompt_message . content , str ) :
if isinstance ( prompt_message . content , list ) :
prompt_message_content = [ ]
for content_item in prompt_message . content or [ ] :
# Skip image if vision is disabled
if not vision_enabled and content_item . type == PromptMessageContentType . IMAGE :
for content_item in prompt_message . content :
# Skip content if features are not defined
if not model_config . model_schema . features :
if content_item . type != PromptMessageContentType . TEXT :
continue
if isinstance ( content_item , ImagePromptMessageContent ) :
# Override vision config if LLM node has vision config,
# cuz vision detail is related to the configuration from FileUpload feature.
content_item . detail = vision_detail
prompt_message_content . append ( content_item )
elif isinstance (
content_item , TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
) :
prompt_message_content . append ( content_item )
continue
if len ( prompt_message_content ) > 1 :
prompt_message . content = prompt_message_content
elif (
len ( prompt_message_content ) == 1 and prompt_message_content [ 0 ] . type == PromptMessageContentType . TEXT
# Skip content if corresponding feature is not supported
if (
(
content_item . type == PromptMessageContentType . IMAGE
and ModelFeature . VISION not in model_config . model_schema . features
)
or (
content_item . type == PromptMessageContentType . DOCUMENT
and ModelFeature . DOCUMENT not in model_config . model_schema . features
)
or (
content_item . type == PromptMessageContentType . VIDEO
and ModelFeature . VIDEO not in model_config . model_schema . features
)
or (
content_item . type == PromptMessageContentType . AUDIO
and ModelFeature . AUDIO not in model_config . model_schema . features
)
) :
raise FileTypeNotSupportError ( type_name = content_item . type )
prompt_message_content . append ( content_item )
if len ( prompt_message_content ) == 1 and prompt_message_content [ 0 ] . type == PromptMessageContentType . TEXT :
prompt_message . content = prompt_message_content [ 0 ] . data
else :
prompt_message . content = prompt_message_content
if prompt_message . is_empty ( ) :
continue
filtered_prompt_messages . append ( prompt_message )
if not filtered_prompt_messages :
if len ( filtered_prompt_messages ) == 0 :
raise NoPromptFoundError (
" No prompt found in the LLM configuration. "
" Please ensure a prompt is properly configured before proceeding. "
)
stop = model_config . stop
return filtered_prompt_messages , stop
@classmethod
@ -715,3 +826,198 @@ class LLMNode(BaseNode[LLMNodeData]):
}
} ,
}
def _combine_text_message_with_role ( * , text : str , role : PromptMessageRole ) :
match role :
case PromptMessageRole . USER :
return UserPromptMessage ( content = [ TextPromptMessageContent ( data = text ) ] )
case PromptMessageRole . ASSISTANT :
return AssistantPromptMessage ( content = [ TextPromptMessageContent ( data = text ) ] )
case PromptMessageRole . SYSTEM :
return SystemPromptMessage ( content = [ TextPromptMessageContent ( data = text ) ] )
raise NotImplementedError ( f " Role { role } is not supported " )
def _render_jinja2_message (
* ,
template : str ,
jinjia2_variables : Sequence [ VariableSelector ] ,
variable_pool : VariablePool ,
) :
if not template :
return " "
jinjia2_inputs = { }
for jinja2_variable in jinjia2_variables :
variable = variable_pool . get ( jinja2_variable . value_selector )
jinjia2_inputs [ jinja2_variable . variable ] = variable . to_object ( ) if variable else " "
code_execute_resp = CodeExecutor . execute_workflow_code_template (
language = CodeLanguage . JINJA2 ,
code = template ,
inputs = jinjia2_inputs ,
)
result_text = code_execute_resp [ " result " ]
return result_text
def _handle_list_messages (
* ,
messages : Sequence [ LLMNodeChatModelMessage ] ,
context : Optional [ str ] ,
jinja2_variables : Sequence [ VariableSelector ] ,
variable_pool : VariablePool ,
vision_detail_config : ImagePromptMessageContent . DETAIL ,
) - > Sequence [ PromptMessage ] :
prompt_messages = [ ]
for message in messages :
if message . edition_type == " jinja2 " :
result_text = _render_jinja2_message (
template = message . jinja2_text or " " ,
jinjia2_variables = jinja2_variables ,
variable_pool = variable_pool ,
)
prompt_message = _combine_text_message_with_role ( text = result_text , role = message . role )
prompt_messages . append ( prompt_message )
else :
# Get segment group from basic message
if context :
template = message . text . replace ( " { #context#} " , context )
else :
template = message . text
segment_group = variable_pool . convert_template ( template )
# Process segments for images
file_contents = [ ]
for segment in segment_group . value :
if isinstance ( segment , ArrayFileSegment ) :
for file in segment . value :
if file . type in { FileType . IMAGE , FileType . VIDEO , FileType . AUDIO , FileType . DOCUMENT } :
file_content = file_manager . to_prompt_message_content (
file , image_detail_config = vision_detail_config
)
file_contents . append ( file_content )
if isinstance ( segment , FileSegment ) :
file = segment . value
if file . type in { FileType . IMAGE , FileType . VIDEO , FileType . AUDIO , FileType . DOCUMENT } :
file_content = file_manager . to_prompt_message_content (
file , image_detail_config = vision_detail_config
)
file_contents . append ( file_content )
# Create message with text from all segments
plain_text = segment_group . text
if plain_text :
prompt_message = _combine_text_message_with_role ( text = plain_text , role = message . role )
prompt_messages . append ( prompt_message )
if file_contents :
# Create message with image contents
prompt_message = UserPromptMessage ( content = file_contents )
prompt_messages . append ( prompt_message )
return prompt_messages
def _calculate_rest_token (
* , prompt_messages : list [ PromptMessage ] , model_config : ModelConfigWithCredentialsEntity
) - > int :
rest_tokens = 2000
model_context_tokens = model_config . model_schema . model_properties . get ( ModelPropertyKey . CONTEXT_SIZE )
if model_context_tokens :
model_instance = ModelInstance (
provider_model_bundle = model_config . provider_model_bundle , model = model_config . model
)
curr_message_tokens = model_instance . get_llm_num_tokens ( prompt_messages )
max_tokens = 0
for parameter_rule in model_config . model_schema . parameter_rules :
if parameter_rule . name == " max_tokens " or (
parameter_rule . use_template and parameter_rule . use_template == " max_tokens "
) :
max_tokens = (
model_config . parameters . get ( parameter_rule . name )
or model_config . parameters . get ( str ( parameter_rule . use_template ) )
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max ( rest_tokens , 0 )
return rest_tokens
def _handle_memory_chat_mode (
* ,
memory : TokenBufferMemory | None ,
memory_config : MemoryConfig | None ,
model_config : ModelConfigWithCredentialsEntity ,
) - > Sequence [ PromptMessage ] :
memory_messages = [ ]
# Get messages from memory for chat model
if memory and memory_config :
rest_tokens = _calculate_rest_token ( prompt_messages = [ ] , model_config = model_config )
memory_messages = memory . get_history_prompt_messages (
max_token_limit = rest_tokens ,
message_limit = memory_config . window . size if memory_config . window . enabled else None ,
)
return memory_messages
def _handle_memory_completion_mode (
* ,
memory : TokenBufferMemory | None ,
memory_config : MemoryConfig | None ,
model_config : ModelConfigWithCredentialsEntity ,
) - > str :
memory_text = " "
# Get history text from memory for completion model
if memory and memory_config :
rest_tokens = _calculate_rest_token ( prompt_messages = [ ] , model_config = model_config )
if not memory_config . role_prefix :
raise MemoryRolePrefixRequiredError ( " Memory role prefix is required for completion model. " )
memory_text = memory . get_history_prompt_text (
max_token_limit = rest_tokens ,
message_limit = memory_config . window . size if memory_config . window . enabled else None ,
human_prefix = memory_config . role_prefix . user ,
ai_prefix = memory_config . role_prefix . assistant ,
)
return memory_text
def _handle_completion_template (
* ,
template : LLMNodeCompletionModelPromptTemplate ,
context : Optional [ str ] ,
jinja2_variables : Sequence [ VariableSelector ] ,
variable_pool : VariablePool ,
) - > Sequence [ PromptMessage ] :
""" Handle completion template processing outside of LLMNode class.
Args :
template : The completion model prompt template
context : Optional context string
jinja2_variables : Variables for jinja2 template rendering
variable_pool : Variable pool for template conversion
Returns :
Sequence of prompt messages
"""
prompt_messages = [ ]
if template . edition_type == " jinja2 " :
result_text = _render_jinja2_message (
template = template . jinja2_text or " " ,
jinjia2_variables = jinja2_variables ,
variable_pool = variable_pool ,
)
else :
if context :
template_text = template . text . replace ( " { #context#} " , context )
else :
template_text = template . text
result_text = variable_pool . convert_template ( template_text ) . text
prompt_message = _combine_text_message_with_role ( text = result_text , role = PromptMessageRole . USER )
prompt_messages . append ( prompt_message )
return prompt_messages