stash: add tracing implementation scaffold in the api
parent
79964f0b32
commit
6b76ef1768
@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Union, Optional, List, Dict
|
||||
|
||||
class WeaveTokenUsage(BaseModel):
|
||||
input_tokens: Optional[int] = None
|
||||
output_tokens: Optional[int] = None
|
||||
total_tokens: Optional[int] = None
|
||||
|
||||
class WeaveMultiModel(BaseModel):
|
||||
file_list: Optional[list[str]] = Field(None, description="List of files")
|
||||
|
||||
|
||||
|
||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
inputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Inputs of the trace")
|
||||
attributes: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Metadata and attributes associated with trace")
|
||||
|
||||
class WeaveTraceUpdateModel(BaseModel):
|
||||
run_id: str = Field(..., description="ID of the run")
|
||||
outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace")
|
||||
@ -0,0 +1,72 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
import weave
|
||||
import wandb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeaveDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
weave_config: WeaveConfig,
|
||||
):
|
||||
super().__init__(weave_config)
|
||||
self.weave_api_key = weave_config.api_key
|
||||
self.project_name = weave_config.project
|
||||
self.entity = weave_config.entity
|
||||
self.weave_client = weave.init(project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.calls = {}
|
||||
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
pass
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
print("Weave login successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Weave API check failed: {str(e)}")
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def add_run(self, run_data: dict, parent_run_id: Optional[str] = None):
|
||||
call = self.weave_client.create_call(op=run_data["name"], inputs=run_data["inputs"])
|
||||
self.calls[run_data["id"]] = call
|
||||
if parent_run_id:
|
||||
self.calls[run_data["id"]].parent_id = parent_run_id
|
||||
|
||||
def update_run(self, run_data: dict):
|
||||
call = self.calls.get(run_data["id"])
|
||||
if call:
|
||||
self.weave_client.finish_call(call, output=run_data["outputs"])
|
||||
else:
|
||||
raise ValueError(f"Call with id {run_data['id']} not found")
|
||||
Loading…
Reference in New Issue