diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py
index a44956186f..f067b5216c 100644
--- a/api/core/ops/weave_trace/entities/weave_trace_entity.py
+++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py
@@ -1,6 +1,9 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
+from pydantic_core.core_schema import ValidationInfo
 from typing import Any, Union, Optional, List, Dict
 
+from core.ops.utils import replace_text_with_content
+
 class WeaveTokenUsage(BaseModel):
     input_tokens: Optional[int] = None
     output_tokens: Optional[int] = None
@@ -12,9 +15,82 @@ class WeaveMultiModel(BaseModel):
 
 
 class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
+    id: str = Field(..., description="ID of the trace")
+    op: str = Field(..., description="Name of the operation")
     inputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Inputs of the trace")
+    outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace")
     attributes: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Metadata and attributes associated with trace")
+    exception: Optional[str] = Field(None, description="Exception message of the trace")
 
-class WeaveTraceUpdateModel(BaseModel):
-    run_id: str = Field(..., description="ID of the run")
-    outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace")
\ No newline at end of file
+    @field_validator("inputs", "outputs")
+    @classmethod
+    def ensure_dict(cls, v, info: ValidationInfo):
+        """Normalise ``inputs``/``outputs`` into a dict carrying usage data.
+
+        ``None`` and ``{}`` pass through. A string becomes one ``messages``
+        (inputs) or ``choices`` (outputs) entry; a list whose first item is a
+        dict goes through ``replace_text_with_content`` first, any other list
+        is stringified under ``choices``. A dict is updated in place.
+        """
+        field_name = info.field_name
+        values = info.data
+        if v == {} or v is None:
+            return v
+        usage_metadata = {
+            "input_tokens": values.get("input_tokens", 0),
+            "output_tokens": values.get("output_tokens", 0),
+            "total_tokens": values.get("total_tokens", 0),
+        }
+        file_list = values.get("file_list", [])
+        if isinstance(v, str):
+            if field_name == "inputs":
+                return {
+                    "messages": {
+                        "role": "user",
+                        "content": v,
+                        "usage_metadata": usage_metadata,
+                        "file_list": file_list,
+                    },
+                }
+            elif field_name == "outputs":
+                return {
+                    "choices": {
+                        "role": "ai",
+                        "content": v,
+                        "usage_metadata": usage_metadata,
+                        "file_list": file_list,
+                    },
+                }
+        elif isinstance(v, list):
+            data = {}
+            if len(v) > 0 and isinstance(v[0], dict):
+                # rename text to content
+                v = replace_text_with_content(data=v)
+                if field_name == "inputs":
+                    data = {
+                        "messages": v,
+                    }
+                elif field_name == "outputs":
+                    data = {
+                        "choices": {
+                            "role": "ai",
+                            "content": v,
+                            "usage_metadata": usage_metadata,
+                            "file_list": file_list,
+                        },
+                    }
+                return data
+            else:
+                return {
+                    "choices": {
+                        "role": "ai" if field_name == "outputs" else "user",
+                        "content": str(v),
+                        "usage_metadata": usage_metadata,
+                        "file_list": file_list,
+                    },
+                }
+        if isinstance(v, dict):
+            v["usage_metadata"] = usage_metadata
+            v["file_list"] = file_list
+            return v
+        return v
\ No newline at end of file
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py
index f030f7a43f..ccd21be9de 100644
--- a/api/core/ops/weave_trace/weave_trace.py
+++ b/api/core/ops/weave_trace/weave_trace.py
@@ -25,6 +25,7 @@ from models.model import EndUser, MessageFile
 from models.workflow import WorkflowNodeExecution
 import weave
 import wandb
+from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
 
 logger = logging.getLogger(__name__)
 
@@ -42,9 +43,91 @@ class WeaveDataTrace(BaseTraceInstance):
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
         self.calls = {}
 
+    def get_project_url(self):
+        """Return the wandb.ai URL of the project this client writes to.
+
+        Raises:
+            ValueError: If the project id cannot be read from the client.
+        """
+        try:
+            return f"https://wandb.ai/{self.weave_client._project_id()}"
+        except Exception as e:
+            logger.debug(f"Weave get run url failed: {str(e)}")
+            raise ValueError(f"Weave get run url failed: {str(e)}") from e
+
 
     def trace(self, trace_info: BaseTraceInfo):
-        pass
+        """Dispatch a trace record to the handler for its concrete type.
+
+        Only message traces are exported so far; every other trace type is
+        accepted and ignored.
+        """
+        logger.debug(f"Trace info: {trace_info}")
+        if isinstance(trace_info, MessageTraceInfo):
+            self.message_trace(trace_info)
+        # TODO: export workflow, moderation, suggested-question,
+        # dataset-retrieval, tool and generate-name traces as well.
+
+    def message_trace(self, trace_info: MessageTraceInfo):
+        """Export one chat message as a parent call plus a child ``llm`` call.
+
+        Does nothing when ``trace_info.message_data`` is ``None``.
+        """
+        message_data = trace_info.message_data
+        if message_data is None:
+            return
+        message_id = message_data.id
+
+        # Copy the list so trace_info is not mutated, and only add a file URL
+        # when the message actually has a file attached.
+        file_list = list(cast(list[str], trace_info.file_list) or [])
+        message_file_data: Optional[MessageFile] = trace_info.message_file_data
+        if message_file_data:
+            file_list.append(f"{self.file_base_url}/{message_file_data.url}")
+
+        metadata = trace_info.metadata
+        metadata["user_id"] = message_data.from_account_id
+
+        if message_data.from_end_user_id:
+            end_user_data: Optional[EndUser] = (
+                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+            )
+            if end_user_data is not None:
+                metadata["end_user_id"] = end_user_data.session_id
+
+        metadata["message_id"] = message_id
+        metadata["start_time"] = trace_info.start_time
+        metadata["end_time"] = trace_info.end_time
+        metadata["tags"] = ["message", str(trace_info.conversation_mode)]
+        message_run = WeaveTraceModel(
+            id=message_id,
+            op=str(TraceTaskName.MESSAGE_TRACE.value),
+            input_tokens=trace_info.message_tokens,
+            output_tokens=trace_info.answer_tokens,
+            total_tokens=trace_info.total_tokens,
+            inputs=trace_info.inputs,
+            outputs=trace_info.outputs,
+            exception=trace_info.error,
+            file_list=file_list,
+            attributes=metadata,
+        )
+        self.add_run(message_run)
+
+        # create llm run parented to message run
+        llm_run = WeaveTraceModel(
+            id=str(uuid.uuid4()),
+            input_tokens=trace_info.message_tokens,
+            output_tokens=trace_info.answer_tokens,
+            total_tokens=trace_info.total_tokens,
+            op="llm",
+            inputs=trace_info.inputs,
+            outputs=trace_info.outputs,
+            attributes=metadata,
+        )
+        self.add_run(llm_run, parent_run_id=message_id)
+        # Finish the child before the parent.
+        self.update_run(llm_run)
+        self.update_run(message_run)
 
     def api_check(self):
         try:
@@ -58,15 +141,23 @@ class WeaveDataTrace(BaseTraceInstance):
             logger.debug(f"Weave API check failed: {str(e)}")
             raise ValueError(f"Weave API check failed: {str(e)}")
 
-    def add_run(self, run_data: dict, parent_run_id: Optional[str] = None):
-        call = self.weave_client.create_call(op=run_data["name"], inputs=run_data["inputs"])
-        self.calls[run_data["id"]] = call
+    def add_run(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
+        """Start a Weave call for ``run_data`` and remember it under ``run_data.id``."""
+        call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
+        self.calls[run_data.id] = call
         if parent_run_id:
-            self.calls[run_data["id"]].parent_id = parent_run_id
+            self.calls[run_data.id].parent_id = parent_run_id
 
-    def update_run(self, run_data: dict):
-        call = self.calls.get(run_data["id"])
+    def update_run(self, run_data: WeaveTraceModel):
+        """Finish the call previously started for ``run_data.id``.
+
+        Raises:
+            ValueError: If no call was started under ``run_data.id``.
+        """
+        call = self.calls.get(run_data.id)
         if call:
-            self.weave_client.finish_call(call, output=run_data["outputs"])
+            self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
         else:
-            raise ValueError(f"Call with id {run_data['id']} not found")
\ No newline at end of file
+            # run_data is a pydantic model, so it must be read by attribute;
+            # subscripting it would raise TypeError instead of this error.
+            raise ValueError(f"Call with id {run_data.id} not found")
\ No newline at end of file
diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index 78340d2bcc..cea62325a7 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -67,7 +67,14 @@ class OpsService:
                 new_decrypt_tracing_config.update({"project_url": project_url})
             except Exception:
                 new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})
-
+        if tracing_provider == "weave" and (
+            "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
+        ):
+            try:
+                project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
+                new_decrypt_tracing_config.update({"project_url": project_url})
+            except Exception:
+                new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
         trace_config_data.tracing_config = new_decrypt_tracing_config
         return trace_config_data.to_dict()
 