diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 7bc458ad58..15c8dc4490 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -39,11 +39,7 @@ class WeaveDataTrace(BaseTraceInstance): self.project_name = weave_config.project self.entity = weave_config.entity self.weave_client = weave.init( - project_name=( - f"{self.entity}/{self.project_name}" - if self.entity - else self.project_name - ) + project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name) ) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.calls: dict[str, Any] = {} @@ -162,25 +158,17 @@ class WeaveDataTrace(BaseTraceInstance): status = node_execution.status if node_type == "llm": inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) - if node_execution.process_data - else {} + json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} ) else: - inputs = ( - json.loads(node_execution.inputs) if node_execution.inputs else {} - ) - outputs = ( - json.loads(node_execution.outputs) if node_execution.outputs else {} - ) + inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} + outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = ( - json.loads(node_execution.execution_metadata) - if node_execution.execution_metadata - else {} + json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} ) node_total_tokens = execution_metadata.get("total_tokens", 0) attributes = execution_metadata.copy() @@ -196,11 +184,7 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = ( - json.loads(node_execution.process_data) - if node_execution.process_data - else {} - ) + process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": attributes.update( { @@ -234,9 +218,7 @@ class WeaveDataTrace(BaseTraceInstance): # get message file data file_list = cast(list[str], trace_info.file_list) or [] message_file_data: Optional[MessageFile] = trace_info.message_file_data - file_url = ( - f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" - ) + file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) attributes = trace_info.metadata message_data = trace_info.message_data @@ -249,9 +231,7 @@ class WeaveDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser) - .filter(EndUser.id == message_data.from_end_user_id) - .first() + db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id @@ -302,12 +282,8 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["tags"] = ["moderation"] attributes["message_id"] = trace_info.message_id - attributes["start_time"] = ( - trace_info.start_time or trace_info.message_data.created_at - ) - attributes["end_time"] = ( - trace_info.end_time or trace_info.message_data.updated_at - ) + attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at + attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at moderation_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -355,12 +331,8 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["message_id"] = trace_info.message_id attributes["tags"] = ["dataset_retrieval"] - attributes["start_time"] = ( - trace_info.start_time or trace_info.message_data.created_at, - ) - attributes["end_time"] = ( - trace_info.end_time or trace_info.message_data.updated_at, - ) + attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) + attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -390,9 +362,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes=attributes, exception=trace_info.error, ) - message_id = trace_info.message_id or getattr( - trace_info, "conversation_id", None - ) + message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) message_id = message_id or None self.start_call(tool_run, parent_run_id=message_id) self.finish_call(tool_run) @@ -418,9 +388,7 @@ class WeaveDataTrace(BaseTraceInstance): def api_check(self): try: - login_status = wandb.login( - key=self.weave_api_key, verify=True, relogin=True - ) + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) if not login_status: raise ValueError("Weave login failed") else: @@ -430,12 +398,8 @@ class WeaveDataTrace(BaseTraceInstance): logger.debug(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}") - def start_call( - self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None - ): - call = self.weave_client.create_call( - op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes - ) + def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): + call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) self.calls[run_data.id] = call if parent_run_id: self.calls[run_data.id].parent_id = parent_run_id @@ -443,8 +407,6 @@ class WeaveDataTrace(BaseTraceInstance): def finish_call(self, run_data: WeaveTraceModel): call = self.calls.get(run_data.id) if call: - self.weave_client.finish_call( - call=call, output=run_data.outputs, exception=run_data.exception - ) + self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) else: raise ValueError(f"Call with id {run_data.id} not found")