stash: add tracing implementation scaffold in the api

pull/14262/head
Bharat Ramanathan 1 year ago
parent d92f57b374
commit 8a741afaae

@ -1,6 +1,9 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo
from typing import Any, Union, Optional, List, Dict from typing import Any, Union, Optional, List, Dict
from core.ops.utils import replace_text_with_content
class WeaveTokenUsage(BaseModel): class WeaveTokenUsage(BaseModel):
input_tokens: Optional[int] = None input_tokens: Optional[int] = None
output_tokens: Optional[int] = None output_tokens: Optional[int] = None
@ -12,9 +15,75 @@ class WeaveMultiModel(BaseModel):
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
    """Payload describing one Weave call (trace) built from a trace record.

    ``inputs`` and ``outputs`` may arrive as a string, a list or a dict; the
    ``ensure_dict`` validator reshapes them into a dict before the model is
    handed to the Weave client.
    """

    id: str = Field(..., description="ID of the trace")
    op: str = Field(..., description="Name of the operation")
    inputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Inputs of the trace")
    outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace")
    attributes: Optional[Union[str, Dict[str, Any], List, None]] = Field(
        None, description="Metadata and attributes associated with trace"
    )
    exception: Optional[str] = Field(None, description="Exception message of the trace")

    @field_validator("inputs", "outputs")
    @classmethod
    def ensure_dict(cls, v, info: ValidationInfo):
        """Wrap ``inputs``/``outputs`` in a chat-style dict envelope.

        Inputs are filed under ``"messages"`` with role ``"user"``; outputs
        are filed under ``"choices"`` with role ``"ai"``. Token usage and the
        file list are read from ``info.data`` and attached to the envelope.

        Args:
            v: Raw field value (``None``, str, list or dict).
            info: Validation context giving the field name and the other
                field values collected so far.

        Returns:
            ``v`` unchanged when it is ``None`` or ``{}``; otherwise a dict.
        """
        if v == {} or v is None:
            return v

        values = info.data
        usage_metadata = {
            "input_tokens": values.get("input_tokens", 0),
            "output_tokens": values.get("output_tokens", 0),
            "total_tokens": values.get("total_tokens", 0),
        }
        file_list = values.get("file_list", [])

        # The validator is only registered for "inputs" and "outputs".
        is_inputs = info.field_name == "inputs"
        envelope_key = "messages" if is_inputs else "choices"

        def as_message(content):
            # One chat message carrying the payload plus usage/file metadata.
            return {
                "role": "user" if is_inputs else "ai",
                "content": content,
                "usage_metadata": usage_metadata,
                "file_list": file_list,
            }

        if isinstance(v, dict):
            v["usage_metadata"] = usage_metadata
            v["file_list"] = file_list
            return v

        if isinstance(v, str):
            return {envelope_key: as_message(v)}

        if isinstance(v, list):
            if v and isinstance(v[0], dict):
                # rename text to content
                v = replace_text_with_content(data=v)
                if is_inputs:
                    return {"messages": v}
                return {"choices": as_message(v)}
            # Any other list is stringified. It used to be filed under
            # "choices" even for inputs; use the same key as the other
            # input paths so consumers find it under "messages".
            return {envelope_key: as_message(str(v))}

        return v

@ -25,6 +25,7 @@ from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution from models.workflow import WorkflowNodeExecution
import weave import weave
import wandb import wandb
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,9 +43,100 @@ class WeaveDataTrace(BaseTraceInstance):
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls = {} self.calls = {}
def get_project_url(self):
    """Return the wandb.ai URL built from the Weave client's project id.

    Raises:
        ValueError: If the project id cannot be read from the Weave client.
            The underlying error is chained as ``__cause__``.
    """
    try:
        return f"https://wandb.ai/{self.weave_client._project_id()}"
    except Exception as e:
        logger.debug("Weave get run url failed: %s", e)
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Weave get run url failed: {str(e)}") from e
def trace(self, trace_info: BaseTraceInfo):
    """Dispatch a trace record to the handler for its concrete type.

    Only ``MessageTraceInfo`` is forwarded (to ``message_trace``). The other
    known trace types are recognised but not exported to Weave yet.

    Args:
        trace_info: The trace record to export.
    """
    logger.debug("Trace info: %s", trace_info)

    if isinstance(trace_info, MessageTraceInfo):
        self.message_trace(trace_info)
    elif isinstance(
        trace_info,
        (
            WorkflowTraceInfo,
            ModerationTraceInfo,
            SuggestedQuestionTraceInfo,
            DatasetRetrievalTraceInfo,
            ToolTraceInfo,
            GenerateNameTraceInfo,
        ),
    ):
        # TODO: add workflow_trace, moderation_trace, suggested_question_trace,
        # dataset_retrieval_trace, tool_trace and generate_name_trace handlers.
        # Payloads go to the debug log rather than stdout so they do not leak
        # into process output.
        logger.debug("Weave tracing is not implemented yet for %s", type(trace_info).__name__)
def message_trace(self, trace_info: MessageTraceInfo):
    """Export a message trace as a Weave call with a nested ``llm`` call.

    Does nothing when ``trace_info.message_data`` is ``None``. Otherwise it
    enriches ``trace_info.metadata`` with user, message, timing and tag
    information, starts a message call and a child ``llm`` call, and then
    finishes both (child first).

    Args:
        trace_info: The message trace record to export.
    """
    message_data = trace_info.message_data
    if message_data is None:
        # Nothing to trace; return before touching the file list or metadata.
        return

    # Work on a copy so the list owned by trace_info is not modified, and
    # only add a file URL when a message file actually exists (previously an
    # empty string was appended otherwise).
    file_list = list(cast(list[str], trace_info.file_list) or [])
    message_file_data: Optional[MessageFile] = trace_info.message_file_data
    if message_file_data:
        file_list.append(f"{self.file_base_url}/{message_file_data.url}")

    message_id = message_data.id
    metadata = trace_info.metadata
    metadata["user_id"] = message_data.from_account_id

    if message_data.from_end_user_id:
        end_user_data: Optional[EndUser] = (
            db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
        )
        if end_user_data is not None:
            metadata["end_user_id"] = end_user_data.session_id

    metadata["message_id"] = message_id
    metadata["start_time"] = trace_info.start_time
    metadata["end_time"] = trace_info.end_time
    metadata["tags"] = ["message", str(trace_info.conversation_mode)]

    message_run = WeaveTraceModel(
        id=message_id,
        op=str(TraceTaskName.MESSAGE_TRACE.value),
        input_tokens=trace_info.message_tokens,
        output_tokens=trace_info.answer_tokens,
        total_tokens=trace_info.total_tokens,
        inputs=trace_info.inputs,
        outputs=trace_info.outputs,
        exception=trace_info.error,
        file_list=file_list,
        attributes=metadata,
    )
    self.add_run(message_run)

    # create llm run parented to message run
    llm_run = WeaveTraceModel(
        id=str(uuid.uuid4()),
        input_tokens=trace_info.message_tokens,
        output_tokens=trace_info.answer_tokens,
        total_tokens=trace_info.total_tokens,
        op="llm",
        inputs=trace_info.inputs,
        outputs=trace_info.outputs,
        attributes=metadata,
    )
    self.add_run(llm_run, parent_run_id=message_id)

    # Finish the child before its parent.
    self.update_run(llm_run)
    self.update_run(message_run)
def api_check(self): def api_check(self):
try: try:
@ -58,15 +150,15 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug(f"Weave API check failed: {str(e)}") logger.debug(f"Weave API check failed: {str(e)}")
raise ValueError(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}")
def add_run(self, run_data: "WeaveTraceModel", parent_run_id: Optional[str] = None):
    """Start a Weave call for ``run_data`` and remember it under ``run_data.id``.

    Args:
        run_data: Trace payload providing ``op``, ``inputs``, ``attributes``
            and the local ``id`` used as the key in ``self.calls``.
        parent_run_id: Local id of a run previously started with ``add_run``.
            When it matches a stored call, that call is passed to the client
            as the parent of the new one.
    """
    # The parent has to be supplied when the call is created. Assigning
    # ``parent_id`` afterwards used the local run id, which is not the id of
    # the stored parent call object.
    parent_call = self.calls.get(parent_run_id) if parent_run_id else None
    call = self.weave_client.create_call(
        op=run_data.op,
        inputs=run_data.inputs,
        parent=parent_call,
        attributes=run_data.attributes,
    )
    self.calls[run_data.id] = call
def update_run(self, run_data: "WeaveTraceModel"):
    """Finish the Weave call that was started for ``run_data``.

    Args:
        run_data: Trace payload whose ``id`` identifies the stored call and
            whose ``outputs`` and ``exception`` are reported on finish.

    Raises:
        ValueError: If no call is stored under ``run_data.id``.
    """
    # Remove the entry as it is finished so ``self.calls`` does not grow for
    # the lifetime of the tracer.
    call = self.calls.pop(run_data.id, None)
    if call is None:
        # ``run_data`` is a model object, not a dict: read the id by attribute
        # so this path raises the intended ValueError.
        raise ValueError(f"Call with id {run_data.id} not found")
    self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)

@ -67,7 +67,14 @@ class OpsService:
new_decrypt_tracing_config.update({"project_url": project_url}) new_decrypt_tracing_config.update({"project_url": project_url})
except Exception: except Exception:
new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"}) new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})
if tracing_provider == "weave" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
trace_config_data.tracing_config = new_decrypt_tracing_config trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict() return trace_config_data.to_dict()

Loading…
Cancel
Save