chore: fix ruff format errors in weave_trace

pull/14262/head
Bharat Ramanathan 1 year ago
parent e7d502cacb
commit 2b0168a5fd

@ -39,11 +39,7 @@ class WeaveDataTrace(BaseTraceInstance):
self.project_name = weave_config.project self.project_name = weave_config.project
self.entity = weave_config.entity self.entity = weave_config.entity
self.weave_client = weave.init( self.weave_client = weave.init(
project_name=( project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
f"{self.entity}/{self.project_name}"
if self.entity
else self.project_name
)
) )
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls: dict[str, Any] = {} self.calls: dict[str, Any] = {}
@ -162,25 +158,17 @@ class WeaveDataTrace(BaseTraceInstance):
status = node_execution.status status = node_execution.status
if node_type == "llm": if node_type == "llm":
inputs = ( inputs = (
json.loads(node_execution.process_data).get("prompts", {}) json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
if node_execution.process_data
else {}
) )
else: else:
inputs = ( inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
json.loads(node_execution.inputs) if node_execution.inputs else {} outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
)
outputs = (
json.loads(node_execution.outputs) if node_execution.outputs else {}
)
created_at = node_execution.created_at or datetime.now() created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = ( execution_metadata = (
json.loads(node_execution.execution_metadata) json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
if node_execution.execution_metadata
else {}
) )
node_total_tokens = execution_metadata.get("total_tokens", 0) node_total_tokens = execution_metadata.get("total_tokens", 0)
attributes = execution_metadata.copy() attributes = execution_metadata.copy()
@ -196,11 +184,7 @@ class WeaveDataTrace(BaseTraceInstance):
} }
) )
process_data = ( process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
json.loads(node_execution.process_data)
if node_execution.process_data
else {}
)
if process_data and process_data.get("model_mode") == "chat": if process_data and process_data.get("model_mode") == "chat":
attributes.update( attributes.update(
{ {
@ -234,9 +218,7 @@ class WeaveDataTrace(BaseTraceInstance):
# get message file data # get message file data
file_list = cast(list[str], trace_info.file_list) or [] file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data message_file_data: Optional[MessageFile] = trace_info.message_file_data
file_url = ( file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
)
file_list.append(file_url) file_list.append(file_url)
attributes = trace_info.metadata attributes = trace_info.metadata
message_data = trace_info.message_data message_data = trace_info.message_data
@ -249,9 +231,7 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser) db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
.filter(EndUser.id == message_data.from_end_user_id)
.first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id
@ -302,12 +282,8 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata attributes = trace_info.metadata
attributes["tags"] = ["moderation"] attributes["tags"] = ["moderation"]
attributes["message_id"] = trace_info.message_id attributes["message_id"] = trace_info.message_id
attributes["start_time"] = ( attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
trace_info.start_time or trace_info.message_data.created_at attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
)
attributes["end_time"] = (
trace_info.end_time or trace_info.message_data.updated_at
)
moderation_run = WeaveTraceModel( moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -355,12 +331,8 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["dataset_retrieval"] attributes["tags"] = ["dataset_retrieval"]
attributes["start_time"] = ( attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
trace_info.start_time or trace_info.message_data.created_at, attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
)
attributes["end_time"] = (
trace_info.end_time or trace_info.message_data.updated_at,
)
dataset_retrieval_run = WeaveTraceModel( dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -390,9 +362,7 @@ class WeaveDataTrace(BaseTraceInstance):
attributes=attributes, attributes=attributes,
exception=trace_info.error, exception=trace_info.error,
) )
message_id = trace_info.message_id or getattr( message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
trace_info, "conversation_id", None
)
message_id = message_id or None message_id = message_id or None
self.start_call(tool_run, parent_run_id=message_id) self.start_call(tool_run, parent_run_id=message_id)
self.finish_call(tool_run) self.finish_call(tool_run)
@ -418,9 +388,7 @@ class WeaveDataTrace(BaseTraceInstance):
def api_check(self): def api_check(self):
try: try:
login_status = wandb.login( login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
key=self.weave_api_key, verify=True, relogin=True
)
if not login_status: if not login_status:
raise ValueError("Weave login failed") raise ValueError("Weave login failed")
else: else:
@ -430,12 +398,8 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug(f"Weave API check failed: {str(e)}") logger.debug(f"Weave API check failed: {str(e)}")
raise ValueError(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}")
def start_call( def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
):
call = self.weave_client.create_call(
op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes
)
self.calls[run_data.id] = call self.calls[run_data.id] = call
if parent_run_id: if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id self.calls[run_data.id].parent_id = parent_run_id
@ -443,8 +407,6 @@ class WeaveDataTrace(BaseTraceInstance):
def finish_call(self, run_data: WeaveTraceModel): def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id) call = self.calls.get(run_data.id)
if call: if call:
self.weave_client.finish_call( self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
call=call, output=run_data.outputs, exception=run_data.exception
)
else: else:
raise ValueError(f"Call with id {run_data.id} not found") raise ValueError(f"Call with id {run_data.id} not found")

Loading…
Cancel
Save