diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 28ccd271e2..0b6834acf3 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Optional, cast +from typing import Optional, Union, cast from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -11,7 +11,6 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import Tracer from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.id_generator import RandomIdGenerator from opentelemetry.trace import SpanContext, TraceFlags, TraceState @@ -36,16 +35,17 @@ from models.workflow import WorkflowNodeExecutionModel logger = logging.getLogger(__name__) -def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[Tracer, SimpleSpanProcessor]: +def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]: """Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix.""" try: # Choose the appropriate exporter based on config type + exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter] if isinstance(arize_phoenix_config, ArizeConfig): arize_endpoint = f"{arize_phoenix_config.endpoint}/v1" arize_headers = { - "api_key": arize_phoenix_config.api_key, - "space_id": arize_phoenix_config.space_id, - "authorization": f"Bearer {arize_phoenix_config.api_key}", + "api_key": arize_phoenix_config.api_key or "", + "space_id": arize_phoenix_config.space_id or "", + "authorization": f"Bearer {arize_phoenix_config.api_key or ''}", } exporter = GrpcOTLPSpanExporter( endpoint=arize_endpoint, @@ -55,8 +55,8 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[Tra else: phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces" phoenix_headers = { - "api_key": arize_phoenix_config.api_key, - "authorization": f"Bearer {arize_phoenix_config.api_key}", + "api_key": arize_phoenix_config.api_key or "", + "authorization": f"Bearer {arize_phoenix_config.api_key or ''}", } exporter = HttpOTLPSpanExporter( endpoint=phoenix_endpoint, @@ -65,8 +65,8 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[Tra ) attributes = { - "openinference.project.name": arize_phoenix_config.project, - "model_id": arize_phoenix_config.project, + "openinference.project.name": arize_phoenix_config.project or "", + "model_id": arize_phoenix_config.project or "", } resource = Resource(attributes=attributes) provider = trace_sdk.TracerProvider(resource=resource) @@ -78,7 +78,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[Tra # Create a named tracer instead of setting the global provider tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}" logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}") - return trace.get_tracer(tracer_name, tracer_provider=provider), processor + return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor except Exception as e: logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True) raise @@ -146,13 +146,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance): return workflow_metadata = { - "workflow_id": trace_info.workflow_run_id, - "message_id": trace_info.message_id, - "workflow_app_log_id": trace_info.workflow_app_log_id, - "status": trace_info.workflow_run_status, + "workflow_id": trace_info.workflow_run_id or "", + "message_id": trace_info.message_id or "", + "workflow_app_log_id": trace_info.workflow_app_log_id or "", + "status": trace_info.workflow_run_status or "", "status_message": trace_info.error or "", "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens, + "total_tokens": trace_info.total_tokens or 0, } workflow_metadata.update(trace_info.metadata) @@ -173,7 +173,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.conversation_id, + SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), context=trace.set_span_in_context(trace.NonRecordingSpan(context)), @@ -232,7 +232,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}", SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind, SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.conversation_id, + SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(created_at), ) @@ -272,18 +272,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance): file_list.append(file_url) message_metadata = { - "message_id": trace_info.message_id, - "conversation_mode": str(trace_info.conversation_mode), - "user_id": trace_info.message_data.from_account_id, - "file_list": file_list, - "status": trace_info.message_data.status, + "message_id": trace_info.message_id or "", + "conversation_mode": str(trace_info.conversation_mode or ""), + "user_id": trace_info.message_data.from_account_id or "", + "file_list": json.dumps(file_list), + "status": trace_info.message_data.status or "", "status_message": trace_info.error or "", "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens, - "prompt_tokens": trace_info.message_tokens, - "completion_tokens": trace_info.answer_tokens, - "ls_provider": trace_info.message_data.model_provider, - "ls_model_name": trace_info.message_data.model_id, + "total_tokens": trace_info.total_tokens or 0, + "prompt_tokens": trace_info.message_tokens or 0, + "completion_tokens": trace_info.answer_tokens or 0, + "ls_provider": trace_info.message_data.model_provider or "", + "ls_model_name": trace_info.message_data.model_id or "", } message_metadata.update(trace_info.metadata) @@ -317,7 +317,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): name=TraceTaskName.MESSAGE_TRACE.value, attributes=attributes, start_time=datetime_to_nanos(trace_info.start_time), - context=trace.set_span_in_context(span_context), + context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)), ) try: @@ -386,7 +386,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): name="llm", attributes=llm_attributes, start_time=datetime_to_nanos(trace_info.start_time), - context=trace.set_span_in_context(message_span), + context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)), ) try: @@ -474,8 +474,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "status_message": trace_info.error or "", "level": "ERROR" if trace_info.error else "DEFAULT", "total_tokens": trace_info.total_tokens, - "ls_provider": trace_info.model_provider, - "ls_model_name": trace_info.model_id, + "ls_provider": trace_info.model_provider or "", + "ls_model_name": trace_info.model_id or "", } metadata.update(trace_info.metadata) @@ -527,8 +527,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "status": trace_info.message_data.status, "status_message": trace_info.message_data.error or "", "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - "ls_provider": trace_info.message_data.model_provider, - "ls_model_name": trace_info.message_data.model_id, + "ls_provider": trace_info.message_data.model_provider or "", + "ls_model_name": trace_info.message_data.model_id or "", } metadata.update(trace_info.metadata) @@ -549,8 +549,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value, SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), - "start_time": start_time.isoformat(), - "end_time": end_time.isoformat(), + "start_time": start_time.isoformat() if start_time else "", + "end_time": end_time.isoformat() if end_time else "", }, start_time=datetime_to_nanos(start_time), context=trace.set_span_in_context(trace.NonRecordingSpan(context)), @@ -594,6 +594,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): trace_state=TraceState(), ) + tool_params_str = ( + json.dumps(trace_info.tool_parameters, ensure_ascii=False) + if isinstance(trace_info.tool_parameters, dict) + else str(trace_info.tool_parameters) + ) + span = self.tracer.start_span( name=trace_info.tool_name, attributes={ @@ -602,10 +608,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), SpanAttributes.TOOL_NAME: trace_info.tool_name, - SpanAttributes.TOOL_PARAMETERS: trace_info.tool_parameters, + SpanAttributes.TOOL_PARAMETERS: tool_params_str, }, start_time=datetime_to_nanos(trace_info.start_time), - context=trace.set_span_in_context(span_context), + context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)), ) try: @@ -652,8 +658,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, - "start_time": trace_info.start_time.isoformat(), - "end_time": trace_info.end_time.isoformat(), + "start_time": trace_info.start_time.isoformat() if trace_info.start_time else "", + "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "", }, start_time=datetime_to_nanos(trace_info.start_time), context=trace.set_span_in_context(trace.NonRecordingSpan(context)),