Merge branch 'langgenius:main' into feat/description-field-for-env-var
commit
bbe88c4541
@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PyProjectConfig(BaseModel):
|
||||
version: str = Field(description="Dify version", default="")
|
||||
|
||||
|
||||
class PyProjectTomlConfig(BaseSettings):
|
||||
"""
|
||||
configs in api/pyproject.toml
|
||||
"""
|
||||
|
||||
project: PyProjectConfig = Field(
|
||||
description="configs in the project section of pyproject.toml",
|
||||
default=PyProjectConfig(),
|
||||
)
|
||||
@ -0,0 +1,380 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Optional, cast, overload
|
||||
|
||||
import json_repair
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
"""
|
||||
Invoke large language model with structured output
|
||||
1. This method invokes model_instance.invoke_llm with json_schema
|
||||
2. Try to parse the result as structured output
|
||||
|
||||
:param prompt_messages: prompt messages
|
||||
:param json_schema: json schema
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# handle native json schema
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
}
|
||||
|
||||
if model_schema.support_structure_output:
|
||||
model_parameters = _handle_native_json_schema(
|
||||
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
|
||||
)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
|
||||
|
||||
# handle prompt based schema
|
||||
prompt_messages = _handle_prompt_based_schema(
|
||||
prompt_messages=prompt_messages,
|
||||
structured_output_schema=json_schema,
|
||||
)
|
||||
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters_with_json_schema,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if isinstance(llm_result, LLMResult):
|
||||
if not isinstance(llm_result.message.content, str):
|
||||
raise OutputParserError(
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(llm_result.message.content),
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
system_fingerprint=llm_result.system_fingerprint,
|
||||
prompt_messages=llm_result.prompt_messages,
|
||||
)
|
||||
else:
|
||||
|
||||
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
result_text: str = ""
|
||||
prompt_messages: Sequence[PromptMessage] = []
|
||||
system_fingerprint: Optional[str] = None
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
elif isinstance(event.delta.message.content, list):
|
||||
for item in event.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(result_text),
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
finish_reason=None,
|
||||
),
|
||||
)
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
structured_output_schema: Mapping,
|
||||
model_parameters: dict,
|
||||
rules: list[ParameterRule],
|
||||
) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
|
||||
def _set_response_format(model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
|
||||
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||
structured_output: Mapping[str, Any] = {}
|
||||
parsed: Mapping[str, Any] = {}
|
||||
try:
|
||||
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except ValidationError:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
temp_parsed = json_repair.loads(result_text)
|
||||
if not isinstance(temp_parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(temp_parsed, list):
|
||||
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = cast(dict, temp_parsed)
|
||||
return structured_output
|
||||
|
||||
|
||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
Different models have different requirements for JSON schema formatting.
|
||||
This function handles these differences.
|
||||
|
||||
:param schema: The original JSON schema
|
||||
:return: Processed schema compatible with the current model
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = dict(deepcopy(schema))
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in model_schema.model:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
@ -0,0 +1,720 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
|
||||
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
|
||||
try:
|
||||
# Choose the appropriate exporter based on config type
|
||||
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
|
||||
if isinstance(arize_phoenix_config, ArizeConfig):
|
||||
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
|
||||
arize_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"space_id": arize_phoenix_config.space_id or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = GrpcOTLPSpanExporter(
|
||||
endpoint=arize_endpoint,
|
||||
headers=arize_headers,
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
|
||||
phoenix_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = HttpOTLPSpanExporter(
|
||||
endpoint=phoenix_endpoint,
|
||||
headers=phoenix_headers,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
attributes = {
|
||||
"openinference.project.name": arize_phoenix_config.project or "",
|
||||
"model_id": arize_phoenix_config.project or "",
|
||||
}
|
||||
resource = Resource(attributes=attributes)
|
||||
provider = trace_sdk.TracerProvider(resource=resource)
|
||||
processor = SimpleSpanProcessor(
|
||||
exporter,
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
# Create a named tracer instead of setting the global provider
|
||||
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
|
||||
logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
|
||||
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def datetime_to_nanos(dt: Optional[datetime]) -> int:
|
||||
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def uuid_to_trace_id(string: Optional[str]) -> int:
|
||||
"""Convert UUID string to a valid trace ID (16-byte integer)."""
|
||||
if string is None:
|
||||
string = ""
|
||||
hash_object = hashlib.sha256(string.encode())
|
||||
|
||||
# Take the first 16 bytes (128 bits) of the hash
|
||||
digest = hash_object.digest()[:16]
|
||||
|
||||
# Convert to integer (128 bits)
|
||||
return int.from_bytes(digest, byteorder="big")
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
workflow_metadata = {
|
||||
"workflow_id": trace_info.workflow_run_id or "",
|
||||
"message_id": trace_info.message_id or "",
|
||||
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
|
||||
"status": trace_info.workflow_run_status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
# Process workflow nodes
|
||||
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
|
||||
node_metadata = {
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": node_execution.tenant_id,
|
||||
"app_id": node_execution.app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
|
||||
}
|
||||
|
||||
if node_execution.execution_metadata:
|
||||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM.value
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_metadata["ls_provider"] = provider
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {}
|
||||
if usage:
|
||||
node_metadata["total_tokens"] = usage.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL.value
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
)
|
||||
|
||||
try:
|
||||
if node_execution.node_type == "llm":
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
|
||||
if model:
|
||||
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
|
||||
|
||||
usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {}
|
||||
if usage:
|
||||
node_span.set_attribute(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens", 0))
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage.get("prompt_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage.get("completion_tokens", 0)
|
||||
)
|
||||
finally:
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
message_metadata = {
|
||||
"message_id": trace_info.message_id or "",
|
||||
"conversation_mode": str(trace_info.conversation_mode or ""),
|
||||
"user_id": trace_info.message_data.from_account_id or "",
|
||||
"file_list": json.dumps(file_list),
|
||||
"status": trace_info.message_data.status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
"prompt_tokens": trace_info.message_tokens or 0,
|
||||
"completion_tokens": trace_info.answer_tokens or 0,
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
message_metadata.update(trace_info.metadata)
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
message_metadata["end_user_id"] = end_user_data.session_id
|
||||
|
||||
attributes = {
|
||||
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
message_span_id = RandomIdGenerator().generate_span_id()
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=message_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
message_span = self.tracer.start_span(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
attributes=attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
message_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
|
||||
# Convert outputs to string based on type
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
llm_attributes = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: outputs_str,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
if isinstance(trace_info.inputs, list):
|
||||
for i, msg in enumerate(trace_info.inputs):
|
||||
if isinstance(msg, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
|
||||
"role", "user"
|
||||
)
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(trace_info.inputs, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(trace_info.inputs, str):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
|
||||
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
|
||||
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
|
||||
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
|
||||
|
||||
if trace_info.message_data.model_id is not None:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
|
||||
if trace_info.message_data.model_provider is not None:
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
|
||||
|
||||
if trace_info.message_data and trace_info.message_data.message_metadata:
|
||||
metadata_dict = json.loads(trace_info.message_data.message_metadata)
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
llm_span = self.tracer.start_span(
|
||||
name="llm",
|
||||
attributes=llm_attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
llm_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
finally:
|
||||
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "moderation",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(
|
||||
{
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "suggested_question",
|
||||
"status": trace_info.status,
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
"ls_provider": trace_info.model_provider or "",
|
||||
"ls_model_name": trace_info.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "dataset_retrieval",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
"start_time": start_time.isoformat() if start_time else "",
|
||||
"end_time": end_time.isoformat() if end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
tool_span_id = RandomIdGenerator().generate_span_id()
|
||||
logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
|
||||
|
||||
# Create span context with the same trace_id as the parent
|
||||
# todo: Create with the appropriate parent span context, so that the tool span is
|
||||
# a child of the appropriate span (e.g. message span)
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=tool_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
if isinstance(trace_info.tool_parameters, dict)
|
||||
else str(trace_info.tool_parameters)
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=trace_info.tool_name,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.TOOL_NAME: trace_info.tool_name,
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"project_name": self.project,
|
||||
"message_id": trace_info.message_id,
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
with self.tracer.start_span("api_check") as span:
|
||||
span.set_attribute("test", "true")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
|
||||
return "https://app.arize.com/"
|
||||
else:
|
||||
return f"{self.arize_phoenix_config.endpoint}/projects/"
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
return workflow_nodes
|
||||
@ -0,0 +1,45 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class DynamicSelectClient(BasePluginClient):
|
||||
def fetch_dynamic_select_options(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
credentials: Mapping[str, Any],
|
||||
parameter: str,
|
||||
) -> PluginDynamicSelectOptionsResponse:
|
||||
"""
|
||||
Fetch dynamic select options for a plugin parameter.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/dynamic_select/fetch_parameter_options",
|
||||
PluginDynamicSelectOptionsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": GenericProviderID(provider).provider_name,
|
||||
"credentials": credentials,
|
||||
"provider_action": action,
|
||||
"parameter": parameter,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for options in response:
|
||||
return options
|
||||
|
||||
raise ValueError(f"Plugin service returned no options for parameter '{parameter}' in provider '{provider}'")
|
||||
@ -0,0 +1,32 @@
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class DraftVariableSaver(Protocol):
|
||||
@abc.abstractmethod
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
|
||||
pass
|
||||
|
||||
|
||||
class DraftVariableSaverFactory(Protocol):
|
||||
@abc.abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
session: Session,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> "DraftVariableSaver":
|
||||
pass
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
|
||||
pass
|
||||
@ -1,16 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
@ -0,0 +1,30 @@
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def search_file_upwards(
|
||||
base_dir_path: Path,
|
||||
target_file_name: str,
|
||||
max_search_parent_depth: int,
|
||||
) -> Path:
|
||||
"""
|
||||
Find a target file in the current directory or its parent directories up to a specified depth.
|
||||
:param base_dir_path: Starting directory path to search from.
|
||||
:param target_file_name: Name of the file to search for.
|
||||
:param max_search_parent_depth: Maximum number of parent directories to search upwards.
|
||||
:return: Path of the file if found, otherwise None.
|
||||
"""
|
||||
current_path = base_dir_path.resolve()
|
||||
for _ in range(max_search_parent_depth):
|
||||
candidate_path = current_path / target_file_name
|
||||
if candidate_path.is_file():
|
||||
return candidate_path
|
||||
parent_path = current_path.parent
|
||||
if parent_path == current_path: # reached the root directory
|
||||
break
|
||||
else:
|
||||
current_path = parent_path
|
||||
|
||||
raise ValueError(
|
||||
f"File '{target_file_name}' not found in the directory '{base_dir_path.resolve()}' or its parent directories"
|
||||
f" in depth of {max_search_parent_depth}."
|
||||
)
|
||||
@ -0,0 +1,74 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
|
||||
|
||||
class PluginParameterService:
|
||||
@staticmethod
|
||||
def get_dynamic_select_options(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
parameter: str,
|
||||
provider_type: Literal["tool"],
|
||||
) -> Sequence[PluginParameterOption]:
|
||||
"""
|
||||
Get dynamic select options for a plugin parameter.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID.
|
||||
plugin_id: The plugin ID.
|
||||
provider: The provider name.
|
||||
action: The action name.
|
||||
parameter: The parameter name.
|
||||
"""
|
||||
credentials: Mapping[str, Any] = {}
|
||||
|
||||
match provider_type:
|
||||
case "tool":
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# check if credentials are required
|
||||
if not provider_controller.need_credentials:
|
||||
credentials = {}
|
||||
else:
|
||||
# fetch credentials from db
|
||||
with Session(db.engine) as session:
|
||||
db_record = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_record is None:
|
||||
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
||||
|
||||
credentials = tool_configuration.decrypt(db_record.credentials)
|
||||
case _:
|
||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||
|
||||
return (
|
||||
DynamicSelectClient()
|
||||
.fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter)
|
||||
.options
|
||||
)
|
||||
@ -0,0 +1,259 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||
"""Test class for WorkflowResponseConverter._fetch_files_from_variable_value method"""
|
||||
|
||||
def create_test_file(self, file_id: str = "test_file_1") -> File:
|
||||
"""Create a test File object"""
|
||||
return File(
|
||||
id=file_id,
|
||||
tenant_id="test_tenant",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related_123",
|
||||
filename=f"{file_id}.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="storage_key_123",
|
||||
)
|
||||
|
||||
def create_file_dict(self, file_id: str = "test_file_dict") -> dict:
|
||||
"""Create a file dictionary with correct dify_model_identity"""
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"id": file_id,
|
||||
"tenant_id": "test_tenant",
|
||||
"type": "document",
|
||||
"transfer_method": "local_file",
|
||||
"related_id": "related_456",
|
||||
"filename": f"{file_id}.txt",
|
||||
"extension": ".txt",
|
||||
"mime_type": "text/plain",
|
||||
"size": 2048,
|
||||
"url": "http://example.com/file.txt",
|
||||
}
|
||||
|
||||
def test_fetch_files_from_variable_value_with_none(self):
|
||||
"""Test with None input"""
|
||||
# The method signature expects Union[dict, list, Segment], but implementation handles None
|
||||
# We'll test the actual behavior by passing an empty dict instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_dict(self):
|
||||
"""Test with empty dictionary"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value({})
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_list(self):
|
||||
"""Test with empty list"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_file_segment(self):
|
||||
"""Test with valid FileSegment"""
|
||||
test_file = self.create_test_file("segment_file")
|
||||
file_segment = FileSegment(value=test_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "segment_file"
|
||||
assert result[0]["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_single(self):
|
||||
"""Test with ArrayFileSegment containing single file"""
|
||||
test_file = self.create_test_file("array_file_1")
|
||||
array_segment = ArrayFileSegment(value=[test_file])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_multiple(self):
|
||||
"""Test with ArrayFileSegment containing multiple files"""
|
||||
test_file_1 = self.create_test_file("array_file_1")
|
||||
test_file_2 = self.create_test_file("array_file_2")
|
||||
array_segment = ArrayFileSegment(value=[test_file_1, test_file_2])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
assert result[1]["id"] == "array_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_empty(self):
|
||||
"""Test with ArrayFileSegment containing empty array"""
|
||||
array_segment = ArrayFileSegment(value=[])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_dicts(self):
|
||||
"""Test with list containing file dictionaries"""
|
||||
file_dict_1 = self.create_file_dict("list_file_1")
|
||||
file_dict_2 = self.create_file_dict("list_file_2")
|
||||
test_list = [file_dict_1, file_dict_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_file_1"
|
||||
assert result[1]["id"] == "list_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_objects(self):
|
||||
"""Test with list containing File objects"""
|
||||
file_obj_1 = self.create_test_file("list_obj_1")
|
||||
file_obj_2 = self.create_test_file("list_obj_2")
|
||||
test_list = [file_obj_1, file_obj_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_obj_1"
|
||||
assert result[1]["id"] == "list_obj_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_mixed_valid_invalid(self):
|
||||
"""Test with list containing mix of valid files and invalid items"""
|
||||
file_dict = self.create_file_dict("mixed_file")
|
||||
invalid_dict = {"not_a_file": "value"}
|
||||
test_list = [file_dict, invalid_dict, "string_item", 123]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == "mixed_file"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_nested_structures(self):
|
||||
"""Test with list containing nested structures"""
|
||||
file_dict = self.create_file_dict("nested_file")
|
||||
nested_list = [file_dict, ["inner_list"]]
|
||||
test_list = [nested_list, {"nested": "dict"}]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
# Should not process nested structures in list items
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_incorrect_identity(self):
|
||||
"""Test with dictionary having incorrect dify_model_identity"""
|
||||
invalid_dict = {"dify_model_identity": "wrong_identity", "id": "invalid_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_missing_identity(self):
|
||||
"""Test with dictionary missing dify_model_identity"""
|
||||
invalid_dict = {"id": "no_identity_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_file_object(self):
|
||||
"""Test with dictionary containing File object"""
|
||||
file_obj = self.create_test_file("dict_obj_file")
|
||||
test_dict = {"file_key": file_obj}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_dict)
|
||||
|
||||
# Should not extract File objects from dict values
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_mixed_data_types(self):
|
||||
"""Test with various mixed data types"""
|
||||
mixed_data = {"string": "text", "number": 42, "boolean": True, "null": None, "dify_model_identity": "wrong"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(mixed_data)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_invalid_objects(self):
|
||||
"""Test with invalid objects that are not supported types"""
|
||||
# Test with an invalid dict that doesn't match expected patterns
|
||||
invalid_dict = {"custom_key": "custom_value"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_string_input(self):
|
||||
"""Test with string input (unsupported type)"""
|
||||
# Since method expects Union[dict, list, Segment], test with empty list instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_number_input(self):
|
||||
"""Test with number input (unsupported type)"""
|
||||
# Test with list containing numbers (should be ignored)
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([42, "string", None])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_return_type_is_sequence(self):
|
||||
"""Test that return type is Sequence[Mapping[str, Any]]"""
|
||||
file_dict = self.create_file_dict("type_test_file")
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_dict)
|
||||
|
||||
assert isinstance(result, Sequence)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Mapping)
|
||||
assert all(isinstance(key, str) for key in result[0])
|
||||
|
||||
def test_fetch_files_from_variable_value_preserves_file_properties(self):
|
||||
"""Test that all file properties are preserved in the result"""
|
||||
original_file = self.create_test_file("property_test")
|
||||
file_segment = FileSegment(value=original_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
file_dict = result[0]
|
||||
assert file_dict["id"] == "property_test"
|
||||
assert file_dict["tenant_id"] == "test_tenant"
|
||||
assert file_dict["type"] == "document"
|
||||
assert file_dict["transfer_method"] == "local_file"
|
||||
assert file_dict["filename"] == "property_test.txt"
|
||||
assert file_dict["extension"] == ".txt"
|
||||
assert file_dict["mime_type"] == "text/plain"
|
||||
assert file_dict["size"] == 1024
|
||||
|
||||
def test_fetch_files_from_variable_value_with_complex_nested_scenario(self):
|
||||
"""Test complex scenario with nested valid and invalid data"""
|
||||
file_dict = self.create_file_dict("complex_file")
|
||||
file_obj = self.create_test_file("complex_obj")
|
||||
|
||||
# Complex nested structure
|
||||
complex_data = [
|
||||
file_dict, # Valid file dict
|
||||
file_obj, # Valid file object
|
||||
{ # Invalid dict
|
||||
"not_file": "data",
|
||||
"nested": {"deep": "value"},
|
||||
},
|
||||
[ # Nested list (should be ignored)
|
||||
self.create_file_dict("nested_file")
|
||||
],
|
||||
"string", # Invalid string
|
||||
None, # None value
|
||||
42, # Invalid number
|
||||
]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(complex_data)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "complex_file"
|
||||
assert result[1]["id"] == "complex_obj"
|
||||
@ -0,0 +1,465 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
|
||||
"""Create a mock LLMUsage with all required fields"""
|
||||
return LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1"),
|
||||
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1"),
|
||||
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
currency="USD",
|
||||
latency=1.5,
|
||||
)
|
||||
|
||||
|
||||
def get_model_entity(provider: str, model_name: str, support_structure_output: bool = False) -> AIModelEntity:
|
||||
"""Create a mock AIModelEntity for testing"""
|
||||
model_schema = MagicMock()
|
||||
model_schema.model = model_name
|
||||
model_schema.provider = provider
|
||||
model_schema.model_type = ModelType.LLM
|
||||
model_schema.model_provider = provider
|
||||
model_schema.model_name = model_name
|
||||
model_schema.support_structure_output = support_structure_output
|
||||
model_schema.parameter_rules = []
|
||||
|
||||
return model_schema
|
||||
|
||||
|
||||
def get_model_instance() -> MagicMock:
|
||||
"""Create a mock ModelInstance for testing"""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.provider = "openai"
|
||||
mock_instance.credentials = {}
|
||||
return mock_instance
|
||||
|
||||
|
||||
def test_structured_output_parser():
|
||||
"""Test cases for invoke_llm_with_structured_output function"""
|
||||
|
||||
testcases = [
|
||||
# Test case 1: Model with native structured output support, non-streaming
|
||||
{
|
||||
"name": "native_structured_output_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 2: Model with native structured output support, streaming
|
||||
{
|
||||
"name": "native_structured_output_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 3: Model without native structured output support, non-streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_non_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"answer": "test response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 4: Model without native structured output support, streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"answer": "test'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 5: Streaming with list content
|
||||
{
|
||||
"name": "streaming_with_list_content",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"data":'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=' "value"}'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
||||
{
|
||||
"name": "error_non_string_content_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content=None), # Non-string content
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": None,
|
||||
"should_raise": True,
|
||||
"expected_error": OutputParserError,
|
||||
},
|
||||
# Test case 7: JSON repair scenario
|
||||
{
|
||||
"name": "json_repair_scenario",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"'), # Invalid JSON - missing closing brace
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 8: Model with parameter rules for response format
|
||||
{
|
||||
"name": "model_with_parameter_rules",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["json_schema"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"result": "success"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 9: Model without native support but with JSON response format rules
|
||||
{
|
||||
"name": "non_native_with_json_rules",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["JSON"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"output": "result"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
]
|
||||
|
||||
for case in testcases:
|
||||
print(f"Running test case: {case['name']}")
|
||||
|
||||
# Setup model entity
|
||||
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
|
||||
|
||||
# Add parameter rules if specified
|
||||
if "parameter_rules" in case:
|
||||
model_schema.parameter_rules = case["parameter_rules"]
|
||||
|
||||
# Setup model instance
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = case["expected_llm_response"]
|
||||
|
||||
# Setup prompt messages
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Generate a response according to the schema."),
|
||||
]
|
||||
|
||||
if case["should_raise"]:
|
||||
# Test error cases
|
||||
with pytest.raises(case["expected_error"]): # noqa: PT012
|
||||
if case["stream"]:
|
||||
result_generator = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
# Consume the generator to trigger the error
|
||||
list(result_generator)
|
||||
else:
|
||||
invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
else:
|
||||
# Test successful cases
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Configure json_repair mock for cases that need it
|
||||
if case["name"] == "json_repair_scenario":
|
||||
mock_json_repair.return_value = {"name": "test"}
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
||||
user="test_user",
|
||||
)
|
||||
|
||||
if case["expected_result_type"] == "generator":
|
||||
# Test streaming results
|
||||
assert hasattr(result, "__iter__")
|
||||
chunks = list(result)
|
||||
assert len(chunks) > 0
|
||||
|
||||
# Verify all chunks are LLMResultChunkWithStructuredOutput
|
||||
for chunk in chunks[:-1]: # All except last
|
||||
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert chunk.model == case["model_name"]
|
||||
|
||||
# Last chunk should have structured output
|
||||
last_chunk = chunks[-1]
|
||||
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert last_chunk.structured_output is not None
|
||||
assert isinstance(last_chunk.structured_output, dict)
|
||||
else:
|
||||
# Test non-streaming results
|
||||
assert isinstance(result, case["expected_result_type"])
|
||||
assert result.model == case["model_name"]
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
|
||||
# Verify model_instance.invoke_llm was called with correct parameters
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
assert call_args.kwargs["stream"] == case["stream"]
|
||||
assert call_args.kwargs["user"] == "test_user"
|
||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||
|
||||
|
||||
def test_parse_structured_output_edge_cases():
|
||||
"""Test edge cases for structured output parsing"""
|
||||
|
||||
# Test case with list that contains dict (reasoning model scenario)
|
||||
testcase_list_with_dict = {
|
||||
"name": "list_with_dict_parsing",
|
||||
"provider": "deepseek",
|
||||
"model_name": "deepseek-r1",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="deepseek-r1",
|
||||
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
}
|
||||
|
||||
# Setup for list parsing test
|
||||
model_schema = get_model_entity(
|
||||
testcase_list_with_dict["provider"],
|
||||
testcase_list_with_dict["model_name"],
|
||||
testcase_list_with_dict["support_structure_output"],
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
||||
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Mock json_repair to return a list with dict
|
||||
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=testcase_list_with_dict["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=testcase_list_with_dict["json_schema"],
|
||||
stream=testcase_list_with_dict["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
assert result.structured_output == {"thought": "reasoning process"}
|
||||
|
||||
|
||||
def test_model_specific_schema_preparation():
|
||||
"""Test schema preparation for different model types"""
|
||||
|
||||
# Test Gemini model
|
||||
gemini_case = {
|
||||
"provider": "google",
|
||||
"model_name": "gemini-pro",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
|
||||
}
|
||||
|
||||
model_schema = get_model_entity(
|
||||
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
model="gemini-pro",
|
||||
message=AssistantPromptMessage(content='{"result": "true"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test")]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=gemini_case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=gemini_case["json_schema"],
|
||||
stream=gemini_case["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
|
||||
# Verify model_instance.invoke_llm was called and check the schema preparation
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
# For Gemini, the schema should not have additionalProperties and boolean should be converted to string
|
||||
assert "json_schema" in call_args.kwargs["model_parameters"]
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue