From d6d1899c772ebf004aa847cb0a4e3740551e72ab Mon Sep 17 00:00:00 2001 From: hjlarry Date: Sun, 27 Apr 2025 10:29:55 +0800 Subject: [PATCH] fix LLMResultChunk cause concatenate str and list exception --- .../__base/large_language_model.py | 4 +++- api/core/model_runtime/utils/helper.py | 16 ++++++++++++++++ api/core/workflow/nodes/llm/node.py | 17 ++++------------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 1b799131e7..c6d017ec01 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -20,6 +20,7 @@ from core.model_runtime.entities.model_entities import ( PriceType, ) from core.model_runtime.model_providers.__base.ai_model import AIModel +from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str from core.plugin.manager.model import PluginModelManager logger = logging.getLogger(__name__) @@ -280,7 +281,8 @@ class LargeLanguageModel(AIModel): callbacks=callbacks, ) - assistant_message.content += chunk.delta.message.content + text = convert_llm_result_chunk_to_str(chunk.delta.message.content) + assistant_message.content += text real_model = chunk.model if chunk.delta.usage: usage = chunk.delta.usage diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index 5e8a723ec7..3035e015a3 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -1,5 +1,6 @@ import pydantic from pydantic import BaseModel +from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes def dump_model(model: BaseModel) -> dict: @@ -8,3 +9,18 @@ def dump_model(model: BaseModel) -> dict: return pydantic.model_dump(model) # type: ignore else: return model.model_dump() + + +def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str: + if content is None: + message_text = "" + elif isinstance(content, str): + message_text = content + elif isinstance(content, list): + # Assuming the list contains PromptMessageContent objects with a "data" attribute + message_text = "".join( + item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content + ) + else: + message_text = str(content) + return message_text \ No newline at end of file diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1089e7168e..9552a51dcf 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -38,6 +38,7 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder +from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str from core.plugin.entities.plugin import ModelProviderID from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -267,20 +268,10 @@ class LLMNode(BaseNode[LLMNodeData]): return self._handle_invoke_result(invoke_result=invoke_result) + def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: if isinstance(invoke_result, LLMResult): - content = invoke_result.message.content - if content is None: - message_text = "" - elif isinstance(content, str): - message_text = content - elif isinstance(content, list): - # Assuming the list contains PromptMessageContent objects with a "data" attribute - message_text = "".join( - item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content - ) - else: - message_text = str(content) + message_text = convert_llm_result_chunk_to_str(invoke_result.message.content) yield ModelInvokeCompletedEvent( text=message_text, @@ -295,7 +286,7 @@ class LLMNode(BaseNode[LLMNodeData]): usage = None finish_reason = None for result in invoke_result: - text = result.delta.message.content + text = convert_llm_result_chunk_to_str(result.delta.message.content) full_text += text yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])