diff --git a/api/core/ops/weave_trace/__init__.py b/api/core/ops/weave_trace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/core/ops/weave_trace/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py new file mode 100644 index 0000000000..a44956186f --- /dev/null +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field +from typing import Any, Union, Optional, List, Dict + +class WeaveTokenUsage(BaseModel): + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + total_tokens: Optional[int] = None + +class WeaveMultiModel(BaseModel): + file_list: Optional[list[str]] = Field(None, description="List of files") + + + +class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): + inputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Inputs of the trace") + attributes: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Metadata and attributes associated with trace") + +class WeaveTraceUpdateModel(BaseModel): + run_id: str = Field(..., description="ID of the run") + outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace") \ No newline at end of file diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py new file mode 100644 index 0000000000..f030f7a43f --- /dev/null +++ b/api/core/ops/weave_trace/weave_trace.py @@ -0,0 +1,72 @@ +import json +import logging +import os +import uuid +from datetime import datetime, timedelta +from typing import Optional, cast + +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.config_entity import WeaveConfig +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) + +from core.ops.utils import filter_none_values, generate_dotted_order +from extensions.ext_database import db +from models.model import EndUser, MessageFile +from models.workflow import WorkflowNodeExecution +import weave +import wandb + +logger = logging.getLogger(__name__) + + +class WeaveDataTrace(BaseTraceInstance): + def __init__( + self, + weave_config: WeaveConfig, + ): + super().__init__(weave_config) + self.weave_api_key = weave_config.api_key + self.project_name = weave_config.project + self.entity = weave_config.entity + self.weave_client = weave.init(project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name) + self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + self.calls = {} + + + def trace(self, trace_info: BaseTraceInfo): + pass + + def api_check(self): + try: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + if not login_status: + raise ValueError("Weave login failed") + else: + print("Weave login successful") + return True + except Exception as e: + logger.debug(f"Weave API check failed: {str(e)}") + raise ValueError(f"Weave API check failed: {str(e)}") + + def add_run(self, run_data: dict, parent_run_id: Optional[str] = None): + call = self.weave_client.create_call(op=run_data["name"], inputs=run_data["inputs"]) + self.calls[run_data["id"]] = call + if parent_run_id: + self.calls[run_data["id"]].parent_id = parent_run_id + + def update_run(self, run_data: dict): + call = self.calls.get(run_data["id"]) + if call: + self.weave_client.finish_call(call, output=run_data["outputs"]) + else: + raise ValueError(f"Call with id {run_data['id']} not found") \ No newline at end of file