|
|
|
|
@ -3,7 +3,7 @@ import logging
|
|
|
|
|
import os
|
|
|
|
|
import uuid
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
from typing import Optional, cast
|
|
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
import wandb
|
|
|
|
|
import weave
|
|
|
|
|
@ -39,10 +39,14 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
self.project_name = weave_config.project
|
|
|
|
|
self.entity = weave_config.entity
|
|
|
|
|
self.weave_client = weave.init(
|
|
|
|
|
project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name
|
|
|
|
|
project_name=(
|
|
|
|
|
f"{self.entity}/{self.project_name}"
|
|
|
|
|
if self.entity
|
|
|
|
|
else self.project_name
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
|
|
|
|
self.calls = {}
|
|
|
|
|
self.calls: dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
|
def get_project_url(
|
|
|
|
|
self,
|
|
|
|
|
@ -103,7 +107,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
workflow_attributes = trace_info.metadata
|
|
|
|
|
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
|
|
|
|
workflow_attributes["trace_id"] = trace_id
|
|
|
|
|
workflow_attributes["start_time"] = trace_info.start
|
|
|
|
|
workflow_attributes["start_time"] = trace_info.start_time
|
|
|
|
|
workflow_attributes["end_time"] = trace_info.end_time
|
|
|
|
|
workflow_attributes["tags"] = ["workflow"]
|
|
|
|
|
|
|
|
|
|
@ -158,17 +162,25 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
status = node_execution.status
|
|
|
|
|
if node_type == "llm":
|
|
|
|
|
inputs = (
|
|
|
|
|
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
|
|
|
|
json.loads(node_execution.process_data).get("prompts", {})
|
|
|
|
|
if node_execution.process_data
|
|
|
|
|
else {}
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
|
|
|
|
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
|
|
|
|
inputs = (
|
|
|
|
|
json.loads(node_execution.inputs) if node_execution.inputs else {}
|
|
|
|
|
)
|
|
|
|
|
outputs = (
|
|
|
|
|
json.loads(node_execution.outputs) if node_execution.outputs else {}
|
|
|
|
|
)
|
|
|
|
|
created_at = node_execution.created_at or datetime.now()
|
|
|
|
|
elapsed_time = node_execution.elapsed_time
|
|
|
|
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
|
|
|
|
|
|
|
|
|
execution_metadata = (
|
|
|
|
|
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
|
|
|
|
json.loads(node_execution.execution_metadata)
|
|
|
|
|
if node_execution.execution_metadata
|
|
|
|
|
else {}
|
|
|
|
|
)
|
|
|
|
|
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
|
|
|
|
attributes = execution_metadata.copy()
|
|
|
|
|
@ -184,7 +196,11 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
|
|
|
|
process_data = (
|
|
|
|
|
json.loads(node_execution.process_data)
|
|
|
|
|
if node_execution.process_data
|
|
|
|
|
else {}
|
|
|
|
|
)
|
|
|
|
|
if process_data and process_data.get("model_mode") == "chat":
|
|
|
|
|
attributes.update(
|
|
|
|
|
{
|
|
|
|
|
@ -206,6 +222,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
file_list=trace_info.file_list,
|
|
|
|
|
attributes=attributes,
|
|
|
|
|
id=node_execution_id,
|
|
|
|
|
exception=None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
|
|
|
|
|
@ -217,7 +234,9 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
# get message file data
|
|
|
|
|
file_list = cast(list[str], trace_info.file_list) or []
|
|
|
|
|
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
|
|
|
|
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
|
|
|
|
file_url = (
|
|
|
|
|
f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
|
|
|
|
)
|
|
|
|
|
file_list.append(file_url)
|
|
|
|
|
attributes = trace_info.metadata
|
|
|
|
|
message_data = trace_info.message_data
|
|
|
|
|
@ -230,7 +249,9 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
|
|
|
|
|
if message_data.from_end_user_id:
|
|
|
|
|
end_user_data: Optional[EndUser] = (
|
|
|
|
|
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
|
|
|
|
db.session.query(EndUser)
|
|
|
|
|
.filter(EndUser.id == message_data.from_end_user_id)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if end_user_data is not None:
|
|
|
|
|
end_user_id = end_user_data.session_id
|
|
|
|
|
@ -264,6 +285,8 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
inputs=trace_info.inputs,
|
|
|
|
|
outputs=trace_info.outputs,
|
|
|
|
|
attributes=attributes,
|
|
|
|
|
file_list=[],
|
|
|
|
|
exception=None,
|
|
|
|
|
)
|
|
|
|
|
self.start_call(
|
|
|
|
|
llm_run,
|
|
|
|
|
@ -279,8 +302,12 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
attributes = trace_info.metadata
|
|
|
|
|
attributes["tags"] = ["moderation"]
|
|
|
|
|
attributes["message_id"] = trace_info.message_id
|
|
|
|
|
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
|
|
|
|
|
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
|
|
|
|
|
attributes["start_time"] = (
|
|
|
|
|
trace_info.start_time or trace_info.message_data.created_at
|
|
|
|
|
)
|
|
|
|
|
attributes["end_time"] = (
|
|
|
|
|
trace_info.end_time or trace_info.message_data.updated_at
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
moderation_run = WeaveTraceModel(
|
|
|
|
|
id=str(uuid.uuid4()),
|
|
|
|
|
@ -293,7 +320,8 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
"inputs": trace_info.inputs,
|
|
|
|
|
},
|
|
|
|
|
attributes=attributes,
|
|
|
|
|
exception=trace_info.error,
|
|
|
|
|
exception=getattr(trace_info, "error", None),
|
|
|
|
|
file_list=[],
|
|
|
|
|
)
|
|
|
|
|
self.start_call(moderation_run, parent_run_id=trace_info.message_id)
|
|
|
|
|
self.finish_call(moderation_run)
|
|
|
|
|
@ -315,6 +343,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
outputs=trace_info.suggested_question,
|
|
|
|
|
attributes=attributes,
|
|
|
|
|
exception=trace_info.error,
|
|
|
|
|
file_list=[],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
|
|
|
|
|
@ -326,8 +355,12 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
attributes = trace_info.metadata
|
|
|
|
|
attributes["message_id"] = trace_info.message_id
|
|
|
|
|
attributes["tags"] = ["dataset_retrieval"]
|
|
|
|
|
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
|
|
|
|
|
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
|
|
|
|
|
attributes["start_time"] = (
|
|
|
|
|
trace_info.start_time or trace_info.message_data.created_at,
|
|
|
|
|
)
|
|
|
|
|
attributes["end_time"] = (
|
|
|
|
|
trace_info.end_time or trace_info.message_data.updated_at,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
dataset_retrieval_run = WeaveTraceModel(
|
|
|
|
|
id=str(uuid.uuid4()),
|
|
|
|
|
@ -335,7 +368,8 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
inputs=trace_info.inputs,
|
|
|
|
|
outputs={"documents": trace_info.documents},
|
|
|
|
|
attributes=attributes,
|
|
|
|
|
exception=trace_info.error,
|
|
|
|
|
exception=getattr(trace_info, "error", None),
|
|
|
|
|
file_list=[],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
|
|
|
|
|
@ -352,11 +386,13 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
op=trace_info.tool_name,
|
|
|
|
|
inputs=trace_info.tool_inputs,
|
|
|
|
|
outputs=trace_info.tool_outputs,
|
|
|
|
|
file_list=[cast(str, trace_info.file_url)],
|
|
|
|
|
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
|
|
|
|
|
attributes=attributes,
|
|
|
|
|
exception=trace_info.error,
|
|
|
|
|
)
|
|
|
|
|
message_id = trace_info.message_id or trace_info.conversation_id
|
|
|
|
|
message_id = trace_info.message_id or getattr(
|
|
|
|
|
trace_info, "conversation_id", None
|
|
|
|
|
)
|
|
|
|
|
message_id = message_id or None
|
|
|
|
|
self.start_call(tool_run, parent_run_id=message_id)
|
|
|
|
|
self.finish_call(tool_run)
|
|
|
|
|
@ -373,7 +409,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
inputs=trace_info.inputs,
|
|
|
|
|
outputs=trace_info.outputs,
|
|
|
|
|
attributes=attributes,
|
|
|
|
|
exception=trace_info.error,
|
|
|
|
|
exception=getattr(trace_info, "error", None),
|
|
|
|
|
file_list=[],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -382,7 +418,9 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
|
|
|
|
|
def api_check(self):
|
|
|
|
|
try:
|
|
|
|
|
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
|
|
|
|
login_status = wandb.login(
|
|
|
|
|
key=self.weave_api_key, verify=True, relogin=True
|
|
|
|
|
)
|
|
|
|
|
if not login_status:
|
|
|
|
|
raise ValueError("Weave login failed")
|
|
|
|
|
else:
|
|
|
|
|
@ -392,8 +430,12 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
logger.debug(f"Weave API check failed: {str(e)}")
|
|
|
|
|
raise ValueError(f"Weave API check failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
def start_call(
    self, run_data: "WeaveTraceModel", parent_run_id: Optional[str] = None
) -> None:
    """Create a Weave call for ``run_data`` and remember it by run id.

    The call is created through ``self.weave_client.create_call`` using the
    run's ``op``, ``inputs`` and ``attributes``, then stored in
    ``self.calls`` under ``run_data.id`` so that ``finish_call`` can look it
    up later.

    Args:
        run_data: Run description; ``id``, ``op``, ``inputs`` and
            ``attributes`` are read from it.
        parent_run_id: When truthy, assigned to the new call's
            ``parent_id`` attribute. ``None`` or an empty string leaves the
            call as created.
    """
    call = self.weave_client.create_call(
        op=run_data.op,
        inputs=run_data.inputs,
        attributes=run_data.attributes,
    )
    self.calls[run_data.id] = call
    if parent_run_id:
        # Same object that was just stored in self.calls, so the registry
        # entry sees the parent link as well.
        call.parent_id = parent_run_id
|
|
|
|
|
@ -401,6 +443,8 @@ class WeaveDataTrace(BaseTraceInstance):
|
|
|
|
|
def finish_call(self, run_data: "WeaveTraceModel") -> None:
    """Finish the Weave call previously registered for ``run_data``.

    Looks up the call stored in ``self.calls`` under ``run_data.id`` and
    passes it to ``self.weave_client.finish_call`` together with the run's
    ``outputs`` and ``exception``.

    Args:
        run_data: Run description; ``id``, ``outputs`` and ``exception``
            are read from it.

    Raises:
        ValueError: If no call is registered under ``run_data.id``.
    """
    call = self.calls.get(run_data.id)
    if not call:
        # run_data is accessed by attribute everywhere else in this method;
        # subscripting it here would raise TypeError and hide the real error.
        raise ValueError(f"Call with id {run_data.id} not found")
    self.weave_client.finish_call(
        call=call,
        output=run_data.outputs,
        exception=run_data.exception,
    )
|
|
|
|
|
|