fix: construct llm attributes via _construct_llm_attributes

pull/22547/head
kurokobo 10 months ago
parent 643bd94539
commit e85680460d

@ -201,20 +201,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN.value
inputs = node_execution.inputs or "{}"
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM.value
provider = process_data.get("model_provider")
model = process_data.get("model_name")
prompts = process_data.get("prompts")
if provider:
node_metadata["ls_provider"] = provider
if model:
node_metadata["ls_model_name"] = model
if prompts:
inputs = prompts
outputs = json.loads(node_execution.outputs).get("usage", {})
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
@ -230,7 +226,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(inputs, ensure_ascii=False),
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
@ -242,27 +238,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
try:
if node_execution.node_type == "llm":
llm_attributes = {
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
}
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
if model:
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
outputs = json.loads(node_execution.outputs).get("usage", {})
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
)
if usage_data:
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
"completion_tokens", 0
)
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
@ -354,25 +350,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
if isinstance(trace_info.inputs, list):
for i, msg in enumerate(trace_info.inputs):
if isinstance(msg, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
"role", "user"
)
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(trace_info.inputs, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(trace_info.inputs, str):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
@ -726,3 +704,24 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
.all()
)
return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str) -> dict:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}
if isinstance(prompts, list):
for i, msg in enumerate(prompts):
if isinstance(msg, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(prompts, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(prompts, str):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
return attributes

Loading…
Cancel
Save