diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 302005922c..8b3ce0c448 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -238,7 +238,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): try: if node_execution.node_type == "llm": - llm_attributes = { + llm_attributes: dict[str, Any] = { SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), } provider = process_data.get("model_provider") @@ -247,7 +247,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes[SpanAttributes.LLM_PROVIDER] = provider if model: llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model - outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} + outputs = ( + json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} + ) usage_data = ( process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) ) @@ -705,7 +707,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) return workflow_nodes - def _construct_llm_attributes(self, prompts: dict | list | str) -> dict: + def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: """Helper method to construct LLM attributes with passed prompts.""" attributes = {} if isinstance(prompts, list):