chore: fix mypy errors in weave_trace

pull/14262/head
Bharat Ramanathan 1 year ago
parent d72603e742
commit e7d502cacb

@ -3,7 +3,7 @@ import logging
import os import os
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, cast from typing import Any, Optional, cast
import wandb import wandb
import weave import weave
@ -39,10 +39,14 @@ class WeaveDataTrace(BaseTraceInstance):
self.project_name = weave_config.project self.project_name = weave_config.project
self.entity = weave_config.entity self.entity = weave_config.entity
self.weave_client = weave.init( self.weave_client = weave.init(
project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name project_name=(
f"{self.entity}/{self.project_name}"
if self.entity
else self.project_name
)
) )
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls = {} self.calls: dict[str, Any] = {}
def get_project_url( def get_project_url(
self, self,
@ -103,7 +107,7 @@ class WeaveDataTrace(BaseTraceInstance):
workflow_attributes = trace_info.metadata workflow_attributes = trace_info.metadata
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
workflow_attributes["trace_id"] = trace_id workflow_attributes["trace_id"] = trace_id
workflow_attributes["start_time"] = trace_info.start workflow_attributes["start_time"] = trace_info.start_time
workflow_attributes["end_time"] = trace_info.end_time workflow_attributes["end_time"] = trace_info.end_time
workflow_attributes["tags"] = ["workflow"] workflow_attributes["tags"] = ["workflow"]
@ -158,17 +162,25 @@ class WeaveDataTrace(BaseTraceInstance):
status = node_execution.status status = node_execution.status
if node_type == "llm": if node_type == "llm":
inputs = ( inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} json.loads(node_execution.process_data).get("prompts", {})
if node_execution.process_data
else {}
) )
else: else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} inputs = (
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} json.loads(node_execution.inputs) if node_execution.inputs else {}
)
outputs = (
json.loads(node_execution.outputs) if node_execution.outputs else {}
)
created_at = node_execution.created_at or datetime.now() created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = ( execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} json.loads(node_execution.execution_metadata)
if node_execution.execution_metadata
else {}
) )
node_total_tokens = execution_metadata.get("total_tokens", 0) node_total_tokens = execution_metadata.get("total_tokens", 0)
attributes = execution_metadata.copy() attributes = execution_metadata.copy()
@ -184,7 +196,11 @@ class WeaveDataTrace(BaseTraceInstance):
} }
) )
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} process_data = (
json.loads(node_execution.process_data)
if node_execution.process_data
else {}
)
if process_data and process_data.get("model_mode") == "chat": if process_data and process_data.get("model_mode") == "chat":
attributes.update( attributes.update(
{ {
@ -206,6 +222,7 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=trace_info.file_list, file_list=trace_info.file_list,
attributes=attributes, attributes=attributes,
id=node_execution_id, id=node_execution_id,
exception=None,
) )
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id) self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
@ -217,7 +234,9 @@ class WeaveDataTrace(BaseTraceInstance):
# get message file data # get message file data
file_list = cast(list[str], trace_info.file_list) or [] file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data message_file_data: Optional[MessageFile] = trace_info.message_file_data
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_url = (
f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
)
file_list.append(file_url) file_list.append(file_url)
attributes = trace_info.metadata attributes = trace_info.metadata
message_data = trace_info.message_data message_data = trace_info.message_data
@ -230,7 +249,9 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser)
.filter(EndUser.id == message_data.from_end_user_id)
.first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id
@ -264,6 +285,8 @@ class WeaveDataTrace(BaseTraceInstance):
inputs=trace_info.inputs, inputs=trace_info.inputs,
outputs=trace_info.outputs, outputs=trace_info.outputs,
attributes=attributes, attributes=attributes,
file_list=[],
exception=None,
) )
self.start_call( self.start_call(
llm_run, llm_run,
@ -279,8 +302,12 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata attributes = trace_info.metadata
attributes["tags"] = ["moderation"] attributes["tags"] = ["moderation"]
attributes["message_id"] = trace_info.message_id attributes["message_id"] = trace_info.message_id
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at attributes["start_time"] = (
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at trace_info.start_time or trace_info.message_data.created_at
)
attributes["end_time"] = (
trace_info.end_time or trace_info.message_data.updated_at
)
moderation_run = WeaveTraceModel( moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -293,7 +320,8 @@ class WeaveDataTrace(BaseTraceInstance):
"inputs": trace_info.inputs, "inputs": trace_info.inputs,
}, },
attributes=attributes, attributes=attributes,
exception=trace_info.error, exception=getattr(trace_info, "error", None),
file_list=[],
) )
self.start_call(moderation_run, parent_run_id=trace_info.message_id) self.start_call(moderation_run, parent_run_id=trace_info.message_id)
self.finish_call(moderation_run) self.finish_call(moderation_run)
@ -315,6 +343,7 @@ class WeaveDataTrace(BaseTraceInstance):
outputs=trace_info.suggested_question, outputs=trace_info.suggested_question,
attributes=attributes, attributes=attributes,
exception=trace_info.error, exception=trace_info.error,
file_list=[],
) )
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id) self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
@ -326,8 +355,12 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["dataset_retrieval"] attributes["tags"] = ["dataset_retrieval"]
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) attributes["start_time"] = (
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) trace_info.start_time or trace_info.message_data.created_at,
)
attributes["end_time"] = (
trace_info.end_time or trace_info.message_data.updated_at,
)
dataset_retrieval_run = WeaveTraceModel( dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -335,7 +368,8 @@ class WeaveDataTrace(BaseTraceInstance):
inputs=trace_info.inputs, inputs=trace_info.inputs,
outputs={"documents": trace_info.documents}, outputs={"documents": trace_info.documents},
attributes=attributes, attributes=attributes,
exception=trace_info.error, exception=getattr(trace_info, "error", None),
file_list=[],
) )
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id) self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
@ -352,11 +386,13 @@ class WeaveDataTrace(BaseTraceInstance):
op=trace_info.tool_name, op=trace_info.tool_name,
inputs=trace_info.tool_inputs, inputs=trace_info.tool_inputs,
outputs=trace_info.tool_outputs, outputs=trace_info.tool_outputs,
file_list=[cast(str, trace_info.file_url)], file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
attributes=attributes, attributes=attributes,
exception=trace_info.error, exception=trace_info.error,
) )
message_id = trace_info.message_id or trace_info.conversation_id message_id = trace_info.message_id or getattr(
trace_info, "conversation_id", None
)
message_id = message_id or None message_id = message_id or None
self.start_call(tool_run, parent_run_id=message_id) self.start_call(tool_run, parent_run_id=message_id)
self.finish_call(tool_run) self.finish_call(tool_run)
@ -373,7 +409,7 @@ class WeaveDataTrace(BaseTraceInstance):
inputs=trace_info.inputs, inputs=trace_info.inputs,
outputs=trace_info.outputs, outputs=trace_info.outputs,
attributes=attributes, attributes=attributes,
exception=trace_info.error, exception=getattr(trace_info, "error", None),
file_list=[], file_list=[],
) )
@ -382,7 +418,9 @@ class WeaveDataTrace(BaseTraceInstance):
def api_check(self): def api_check(self):
try: try:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) login_status = wandb.login(
key=self.weave_api_key, verify=True, relogin=True
)
if not login_status: if not login_status:
raise ValueError("Weave login failed") raise ValueError("Weave login failed")
else: else:
@ -392,8 +430,12 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug(f"Weave API check failed: {str(e)}") logger.debug(f"Weave API check failed: {str(e)}")
raise ValueError(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}")
def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): def start_call(
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None
):
call = self.weave_client.create_call(
op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes
)
self.calls[run_data.id] = call self.calls[run_data.id] = call
if parent_run_id: if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id self.calls[run_data.id].parent_id = parent_run_id
@ -401,6 +443,8 @@ class WeaveDataTrace(BaseTraceInstance):
def finish_call(self, run_data: WeaveTraceModel): def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id) call = self.calls.get(run_data.id)
if call: if call:
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) self.weave_client.finish_call(
call=call, output=run_data.outputs, exception=run_data.exception
)
else: else:
raise ValueError(f"Call with id {run_data['id']} not found") raise ValueError(f"Call with id {run_data.id} not found")

Loading…
Cancel
Save