From e7d502cacbb36f2dd8a773a0c0f1b51da3bad75f Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Mon, 21 Apr 2025 15:30:19 +0530 Subject: [PATCH] chore: fix mypy errors in weave_trace --- api/core/ops/weave_trace/weave_trace.py | 94 ++++++++++++++++++------- 1 file changed, 69 insertions(+), 25 deletions(-) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index ce474f700d..7bc458ad58 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -3,7 +3,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import Any, Optional, cast import wandb import weave @@ -39,10 +39,14 @@ class WeaveDataTrace(BaseTraceInstance): self.project_name = weave_config.project self.entity = weave_config.entity self.weave_client = weave.init( - project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name + project_name=( + f"{self.entity}/{self.project_name}" + if self.entity + else self.project_name + ) ) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") - self.calls = {} + self.calls: dict[str, Any] = {} def get_project_url( self, @@ -103,7 +107,7 @@ class WeaveDataTrace(BaseTraceInstance): workflow_attributes = trace_info.metadata workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id workflow_attributes["trace_id"] = trace_id - workflow_attributes["start_time"] = trace_info.start + workflow_attributes["start_time"] = trace_info.start_time workflow_attributes["end_time"] = trace_info.end_time workflow_attributes["tags"] = ["workflow"] @@ -158,17 +162,25 @@ class WeaveDataTrace(BaseTraceInstance): status = node_execution.status if node_type == "llm": inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} + json.loads(node_execution.process_data).get("prompts", {}) + if node_execution.process_data + else {} ) else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = ( + json.loads(node_execution.inputs) if node_execution.inputs else {} + ) + outputs = ( + json.loads(node_execution.outputs) if node_execution.outputs else {} + ) created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + json.loads(node_execution.execution_metadata) + if node_execution.execution_metadata + else {} ) node_total_tokens = execution_metadata.get("total_tokens", 0) attributes = execution_metadata.copy() @@ -184,7 +196,11 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = ( + json.loads(node_execution.process_data) + if node_execution.process_data + else {} + ) if process_data and process_data.get("model_mode") == "chat": attributes.update( { @@ -206,6 +222,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=trace_info.file_list, attributes=attributes, id=node_execution_id, + exception=None, ) self.start_call(node_run, parent_run_id=trace_info.workflow_run_id) @@ -217,7 +234,9 @@ class WeaveDataTrace(BaseTraceInstance): # get message file data file_list = cast(list[str], trace_info.file_list) or [] message_file_data: Optional[MessageFile] = trace_info.message_file_data - file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" + file_url = ( + f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" + ) file_list.append(file_url) attributes = trace_info.metadata message_data = trace_info.message_data @@ -230,7 +249,9 @@ class WeaveDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser) + .filter(EndUser.id == message_data.from_end_user_id) + .first() ) if end_user_data is not None: end_user_id = end_user_data.session_id @@ -264,6 +285,8 @@ class WeaveDataTrace(BaseTraceInstance): inputs=trace_info.inputs, outputs=trace_info.outputs, attributes=attributes, + file_list=[], + exception=None, ) self.start_call( llm_run, @@ -279,8 +302,12 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["tags"] = ["moderation"] attributes["message_id"] = trace_info.message_id - attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at - attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at + attributes["start_time"] = ( + trace_info.start_time or trace_info.message_data.created_at + ) + attributes["end_time"] = ( + trace_info.end_time or trace_info.message_data.updated_at + ) moderation_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -293,7 +320,8 @@ class WeaveDataTrace(BaseTraceInstance): "inputs": trace_info.inputs, }, attributes=attributes, - exception=trace_info.error, + exception=getattr(trace_info, "error", None), + file_list=[], ) self.start_call(moderation_run, parent_run_id=trace_info.message_id) self.finish_call(moderation_run) @@ -315,6 +343,7 @@ class WeaveDataTrace(BaseTraceInstance): outputs=trace_info.suggested_question, attributes=attributes, exception=trace_info.error, + file_list=[], ) self.start_call(suggested_question_run, parent_run_id=trace_info.message_id) @@ -326,8 +355,12 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["message_id"] = trace_info.message_id attributes["tags"] = ["dataset_retrieval"] - attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) - attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) + attributes["start_time"] = ( + trace_info.start_time or trace_info.message_data.created_at, + ) + attributes["end_time"] = ( + trace_info.end_time or trace_info.message_data.updated_at, + ) dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -335,7 +368,8 @@ class WeaveDataTrace(BaseTraceInstance): inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, attributes=attributes, - exception=trace_info.error, + exception=getattr(trace_info, "error", None), + file_list=[], ) self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id) @@ -352,11 +386,13 @@ class WeaveDataTrace(BaseTraceInstance): op=trace_info.tool_name, inputs=trace_info.tool_inputs, outputs=trace_info.tool_outputs, - file_list=[cast(str, trace_info.file_url)], + file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [], attributes=attributes, exception=trace_info.error, ) - message_id = trace_info.message_id or trace_info.conversation_id + message_id = trace_info.message_id or getattr( + trace_info, "conversation_id", None + ) message_id = message_id or None self.start_call(tool_run, parent_run_id=message_id) self.finish_call(tool_run) @@ -373,7 +409,7 @@ class WeaveDataTrace(BaseTraceInstance): inputs=trace_info.inputs, outputs=trace_info.outputs, attributes=attributes, - exception=trace_info.error, + exception=getattr(trace_info, "error", None), file_list=[], ) @@ -382,7 +418,9 @@ class WeaveDataTrace(BaseTraceInstance): def api_check(self): try: - login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + login_status = wandb.login( + key=self.weave_api_key, verify=True, relogin=True + ) if not login_status: raise ValueError("Weave login failed") else: @@ -392,8 +430,12 @@ class WeaveDataTrace(BaseTraceInstance): logger.debug(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}") - def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): - call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) + def start_call( + self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None + ): + call = self.weave_client.create_call( + op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes + ) self.calls[run_data.id] = call if parent_run_id: self.calls[run_data.id].parent_id = parent_run_id @@ -401,6 +443,8 @@ class WeaveDataTrace(BaseTraceInstance): def finish_call(self, run_data: WeaveTraceModel): call = self.calls.get(run_data.id) if call: - self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) + self.weave_client.finish_call( + call=call, output=run_data.outputs, exception=run_data.exception + ) else: - raise ValueError(f"Call with id {run_data['id']} not found") + raise ValueError(f"Call with id {run_data.id} not found")