diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 92b6cbae13..82c0654e33 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -50,7 +50,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[Tra exporter = GrpcOTLPSpanExporter( endpoint=arize_endpoint, headers=arize_headers, - timeout=30 + timeout=30, ) else: phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces" @@ -61,12 +61,12 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[Tra exporter = HttpOTLPSpanExporter( endpoint=phoenix_endpoint, headers=phoenix_headers, - timeout=30 + timeout=30, ) attributes = { "openinference.project.name": arize_phoenix_config.project, - "model_id": arize_phoenix_config.project + "model_id": arize_phoenix_config.project, } resource = Resource(attributes=attributes) provider = trace_sdk.TracerProvider(resource=resource) @@ -107,6 +107,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ): super().__init__(arize_phoenix_config) import logging + logging.basicConfig() logging.getLogger().setLevel(logging.DEBUG) self.arize_phoenix_config = arize_phoenix_config @@ -157,7 +158,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): trace_id=trace_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), - trace_state=TraceState() + trace_state=TraceState(), ) workflow_span = self.tracer.start_span( @@ -180,8 +181,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - process_data = json.loads( - node_execution.process_data) if node_execution.process_data else {} + process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} node_metadata = { "node_id": node_execution.id, @@ -195,8 +195,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } if node_execution.execution_metadata: - node_metadata.update(json.loads( - node_execution.execution_metadata)) + node_metadata.update(json.loads(node_execution.execution_metadata)) # Determine the correct span kind based on node type span_kind = OpenInferenceSpanKindValues.CHAIN.value @@ -209,15 +208,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance): if model: node_metadata["ls_model_name"] = model - usage = json.loads(node_execution.outputs).get( - "usage", {}) if node_execution.outputs else {} + usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {} if usage: - node_metadata["total_tokens"] = usage.get( - "total_tokens", 0) - node_metadata["prompt_tokens"] = usage.get( - "prompt_tokens", 0) - node_metadata["completion_tokens"] = usage.get( - "completion_tokens", 0) + node_metadata["total_tokens"] = usage.get("total_tokens", 0) + node_metadata["prompt_tokens"] = usage.get("prompt_tokens", 0) + node_metadata["completion_tokens"] = usage.get("completion_tokens", 0) elif node_execution.node_type == "dataset_retrieval": span_kind = OpenInferenceSpanKindValues.RETRIEVER.value elif node_execution.node_type == "tool": @@ -242,21 +237,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance): provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: - node_span.set_attribute( - SpanAttributes.LLM_PROVIDER, provider) + node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider) if model: - node_span.set_attribute( - SpanAttributes.LLM_MODEL_NAME, model) + node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model) - usage = json.loads(node_execution.outputs).get( - "usage", {}) if node_execution.outputs else {} + usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {} if usage: + node_span.set_attribute(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens", 0)) node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens", 0)) + SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage.get("prompt_tokens", 0) + ) node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage.get("prompt_tokens", 0)) - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage.get("completion_tokens", 0)) + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage.get("completion_tokens", 0) + ) finally: node_span.end(end_time=datetime_to_nanos(finished_at)) finally: @@ -292,9 +285,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser) - .filter(EndUser.id == trace_info.message_data.from_end_user_id) - .first() + db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: message_metadata["end_user_id"] = end_user_data.session_id @@ -314,7 +305,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span_id=message_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), - trace_state=TraceState() + trace_state=TraceState(), ) message_span = self.tracer.start_span( @@ -331,14 +322,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance): attributes={ "exception.message": trace_info.error, "exception.type": "Error", - "exception.stacktrace": trace_info.error - } + "exception.stacktrace": trace_info.error, + }, ) # Convert outputs to string based on type if isinstance(trace_info.outputs, dict | list): - outputs_str = json.dumps( - trace_info.outputs, ensure_ascii=False) + outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False) elif isinstance(trace_info.outputs, str): outputs_str = trace_info.outputs else: @@ -355,17 +345,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance): if isinstance(trace_info.inputs, list): for i, msg in enumerate(trace_info.inputs): if isinstance(msg, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get( - "text", "") + llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get( - "role", "user") + "role", "user" + ) # todo: handle assistant and tool role messages, as they don't always # have a text field, but may have a tool_calls field instead # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} elif isinstance(trace_info.inputs, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps( - trace_info.inputs) + llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs) llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" elif isinstance(trace_info.inputs, str): llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs @@ -384,11 +373,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider if trace_info.message_data and trace_info.message_data.message_metadata: - metadata_dict = json.loads( - trace_info.message_data.message_metadata) + metadata_dict = json.loads(trace_info.message_data.message_metadata) if model_params := metadata_dict.get("model_parameters"): - llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps( - model_params) + llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params) llm_span = self.tracer.start_span( name="llm", @@ -404,8 +391,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): attributes={ "exception.message": trace_info.error, "exception.type": "Error", - "exception.stacktrace": trace_info.error - } + "exception.stacktrace": trace_info.error, + }, ) finally: llm_span.end(end_time=datetime_to_nanos(trace_info.end_time)) @@ -432,19 +419,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span_id=span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), - trace_state=TraceState() + trace_state=TraceState(), ) span = self.tracer.start_span( name=TraceTaskName.MODERATION_TRACE.value, attributes={ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps({ - "action": trace_info.action, - "flagged": trace_info.flagged, - "preset_response": trace_info.preset_response, - "inputs": trace_info.inputs, - }, ensure_ascii=False), + SpanAttributes.OUTPUT_VALUE: json.dumps( + { + "action": trace_info.action, + "flagged": trace_info.flagged, + "preset_response": trace_info.preset_response, + "inputs": trace_info.inputs, + }, + ensure_ascii=False, + ), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), }, @@ -459,8 +449,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): attributes={ "exception.message": trace_info.error, "exception.type": "Error", - "exception.stacktrace": trace_info.error - } + "exception.stacktrace": trace_info.error, + }, ) finally: span.end(end_time=datetime_to_nanos(trace_info.end_time)) @@ -491,7 +481,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span_id=span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), - trace_state=TraceState() + trace_state=TraceState(), ) span = self.tracer.start_span( @@ -513,8 +503,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): attributes={ "exception.message": trace_info.error, "exception.type": "Error", - "exception.stacktrace": trace_info.error - } + "exception.stacktrace": trace_info.error, + }, ) finally: span.end(end_time=datetime_to_nanos(end_time)) @@ -544,7 +534,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span_id=span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), - trace_state=TraceState() + trace_state=TraceState(), ) span = self.tracer.start_span( @@ -568,8 +558,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): attributes={ "exception.message": trace_info.error, "exception.type": "Error", - "exception.stacktrace": trace_info.error - } + "exception.stacktrace": trace_info.error, + }, ) finally: span.end(end_time=datetime_to_nanos(end_time)) @@ -596,7 +586,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span_id=tool_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), - trace_state=TraceState() + trace_state=TraceState(), ) span = self.tracer.start_span( @@ -620,8 +610,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): attributes={ "exception.message": trace_info.error, "exception.type": "Error", - "exception.stacktrace": trace_info.error - } + "exception.stacktrace": trace_info.error, + }, ) finally: span.end(end_time=datetime_to_nanos(trace_info.end_time)) @@ -646,7 +636,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span_id=span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), - trace_state=TraceState() + trace_state=TraceState(), ) span = self.tracer.start_span( @@ -671,8 +661,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): attributes={ "exception.message": trace_info.error, "exception.type": "Error", - "exception.stacktrace": trace_info.error - } + "exception.stacktrace": trace_info.error, + }, ) finally: span.end(end_time=datetime_to_nanos(trace_info.end_time))