refactor: enhance structured output handling in LLM generator

- Introduced `LLMResultChunkWithStructuredOutput` to encapsulate structured output within result chunks, improving data organization.
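- Updated the streaming path of `invoke_llm_with_structured_output` to yield `LLMResultChunkWithStructuredOutput` for every chunk, followed by a final chunk that carries the parsed structured output, instead of yielding a bare `LLMStructuredOutput` at the end.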
- Modified the `LLMStructuredOutput` class so that `structured_output` is optional and defaults to `None`, allowing chunks that carry no structured output.
- Added a new request model `RequestInvokeLLMWithStructuredOutput` to facilitate structured output requests, improving API usability.
pull/21565/head
Yeuoly 11 months ago
parent 9dade0ad5a
commit e5719de784

@ -12,10 +12,16 @@ from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import ( from core.model_runtime.entities.llm_entities import (
LLMResult, LLMResult,
LLMResultChunk, LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput, LLMResultWithStructuredOutput,
LLMStructuredOutput,
) )
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, SystemPromptMessage from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
@ -34,7 +40,7 @@ def invoke_llm_with_structured_output(
stream: Literal[True] = True, stream: Literal[True] = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Generator[LLMResultChunk | LLMStructuredOutput, None, None]: ... ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload @overload
@ -66,7 +72,7 @@ def invoke_llm_with_structured_output(
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunk | LLMStructuredOutput, None, None]: ... ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output( def invoke_llm_with_structured_output(
@ -81,7 +87,7 @@ def invoke_llm_with_structured_output(
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunk | LLMStructuredOutput, None, None]: ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
""" """
Invoke large language model with structured output Invoke large language model with structured output
1. This method invokes model_instance.invoke_llm with json_schema 1. This method invokes model_instance.invoke_llm with json_schema
@ -143,14 +149,36 @@ def invoke_llm_with_structured_output(
) )
else: else:
def generator() -> Generator[LLMStructuredOutput, None, None]: def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
result_text = "" result_text = ""
prompt_messages = []
system_fingerprint = None
for event in llm_result: for event in llm_result:
if isinstance(event, LLMResultChunk): if isinstance(event, LLMResultChunk):
if isinstance(event.delta.message.content, str): if isinstance(event.delta.message.content, str):
result_text += event.delta.message.content result_text += event.delta.message.content
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
yield LLMResultChunkWithStructuredOutput(
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=event.delta,
)
yield LLMStructuredOutput(structured_output=_parse_structured_output(result_text)) yield LLMResultChunkWithStructuredOutput(
structured_output=_parse_structured_output(result_text),
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=""),
usage=None,
finish_reason=None,
),
)
return generator() return generator()

@ -106,7 +106,7 @@ class LLMStructuredOutput(BaseModel):
Model class for llm structured output. Model class for llm structured output.
""" """
structured_output: Mapping[str, Any] structured_output: Optional[Mapping[str, Any]] = None
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
@ -137,6 +137,12 @@ class LLMResultChunk(BaseModel):
delta: LLMResultChunkDelta delta: LLMResultChunkDelta
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
    """
    Model class for LLM result chunk with structured output.

    Combines LLMResultChunk and LLMStructuredOutput through multiple
    inheritance and declares no fields of its own.
    """
class NumTokensResult(PriceInfo): class NumTokensResult(PriceInfo):
""" """
Model class for number of tokens result. Model class for number of tokens result.

@ -3,7 +3,11 @@ from binascii import hexlify, unhexlify
from collections.abc import Generator from collections.abc import Generator
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
PromptMessage, PromptMessage,
SystemPromptMessage, SystemPromptMessage,

@ -82,6 +82,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
return v return v
class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
    """
    Request to invoke an LLM with structured output.

    Extends RequestInvokeLLM with the schema that describes the structured
    output.
    """

    # JSON-schema description of the structured output; an empty dict is used when the field is omitted.
    structured_output_schema: dict[str, Any] = Field(
        default_factory=dict, description="The schema of the structured output in JSON schema format"
    )
class RequestInvokeTextEmbedding(BaseRequestInvokeModel): class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
""" """
Request to invoke text embedding Request to invoke text embedding

@ -18,7 +18,13 @@ from core.model_runtime.entities import (
PromptMessageContentType, PromptMessageContentType,
TextPromptMessageContent, TextPromptMessageContent,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMStructuredOutput, LLMUsage from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMStructuredOutput,
LLMUsage,
)
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessageContentUnionTypes, PromptMessageContentUnionTypes,
@ -344,6 +350,8 @@ class LLMNode(BaseNode[LLMNodeData]):
# Consume the invoke result and handle generator exception # Consume the invoke result and handle generator exception
try: try:
for result in invoke_result: for result in invoke_result:
if isinstance(result, LLMResultChunkWithStructuredOutput):
yield result
if isinstance(result, LLMResultChunk): if isinstance(result, LLMResultChunk):
contents = result.delta.message.content contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
@ -363,8 +371,6 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = result.delta.usage usage = result.delta.usage
if finish_reason is None and result.delta.finish_reason: if finish_reason is None and result.delta.finish_reason:
finish_reason = result.delta.finish_reason finish_reason = result.delta.finish_reason
elif isinstance(result, LLMStructuredOutput):
yield result
except OutputParserError as e: except OutputParserError as e:
raise LLMNodeError(f"Failed to parse structured output: {e}") raise LLMNodeError(f"Failed to parse structured output: {e}")

Loading…
Cancel
Save