chore: run linting and formatting

pull/14262/head
Bharat Ramanathan 1 year ago
parent abfc76b69f
commit 88ad5fcdfe

@ -88,6 +88,7 @@ class OpikConfig(BaseTracingConfig):
return v return v
class WeaveConfig(BaseTracingConfig): class WeaveConfig(BaseTracingConfig):
""" """
Model class for Weave tracing config. Model class for Weave tracing config.

@ -18,8 +18,8 @@ from core.ops.entities.config_entity import (
LangfuseConfig, LangfuseConfig,
LangSmithConfig, LangSmithConfig,
OpikConfig, OpikConfig,
WeaveConfig,
TracingProviderEnum, TracingProviderEnum,
WeaveConfig,
) )
from core.ops.entities.trace_entity import ( from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo, DatasetRetrievalTraceInfo,
@ -34,9 +34,9 @@ from core.ops.entities.trace_entity import (
) )
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.weave_trace.weave_trace import WeaveDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data from core.ops.utils import get_message_data
from core.ops.weave_trace.weave_trace import WeaveDataTrace
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig

@ -1,25 +1,29 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from typing import Any, Union, Optional, List, Dict
from core.ops.utils import replace_text_with_content from core.ops.utils import replace_text_with_content
class WeaveTokenUsage(BaseModel): class WeaveTokenUsage(BaseModel):
input_tokens: Optional[int] = None input_tokens: Optional[int] = None
output_tokens: Optional[int] = None output_tokens: Optional[int] = None
total_tokens: Optional[int] = None total_tokens: Optional[int] = None
class WeaveMultiModel(BaseModel): class WeaveMultiModel(BaseModel):
file_list: Optional[list[str]] = Field(None, description="List of files") file_list: Optional[list[str]] = Field(None, description="List of files")
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
id: str = Field(..., description="ID of the trace") id: str = Field(..., description="ID of the trace")
op: str = Field(..., description="Name of the operation") op: str = Field(..., description="Name of the operation")
inputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Inputs of the trace") inputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace") outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace")
attributes: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Metadata and attributes associated with trace") attributes: Optional[Union[str, Dict[str, Any], List, None]] = Field(
None, description="Metadata and attributes associated with trace"
)
exception: Optional[str] = Field(None, description="Exception message of the trace") exception: Optional[str] = Field(None, description="Exception message of the trace")
@field_validator("inputs", "outputs") @field_validator("inputs", "outputs")

@ -5,6 +5,9 @@ import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, cast from typing import Optional, cast
import wandb
import weave
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import ( from core.ops.entities.trace_entity import (
@ -18,14 +21,10 @@ from core.ops.entities.trace_entity import (
TraceTaskName, TraceTaskName,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.ops.utils import filter_none_values, generate_dotted_order
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution from models.workflow import WorkflowNodeExecution
import weave
import wandb
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,11 +38,15 @@ class WeaveDataTrace(BaseTraceInstance):
self.weave_api_key = weave_config.api_key self.weave_api_key = weave_config.api_key
self.project_name = weave_config.project self.project_name = weave_config.project
self.entity = weave_config.entity self.entity = weave_config.entity
self.weave_client = weave.init(project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name) self.weave_client = weave.init(
project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls = {} self.calls = {}
def get_project_url(self,): def get_project_url(
self,
):
try: try:
project_url = f"https://wandb.ai/{self.weave_client._project_id()}" project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
return project_url return project_url
@ -51,36 +54,163 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug(f"Weave get run url failed: {str(e)}") logger.debug(f"Weave get run url failed: {str(e)}")
raise ValueError(f"Weave get run url failed: {str(e)}") raise ValueError(f"Weave get run url failed: {str(e)}")
def trace(self, trace_info: BaseTraceInfo): def trace(self, trace_info: BaseTraceInfo):
logger.debug(f"Trace info: {trace_info}") logger.debug(f"Trace info: {trace_info}")
print("Trace info: ", trace_info)
if isinstance(trace_info, WorkflowTraceInfo): if isinstance(trace_info, WorkflowTraceInfo):
# self.workflow_trace(trace_info) self.workflow_trace(trace_info)
print("Workflow trace: ", trace_info)
pass
if isinstance(trace_info, MessageTraceInfo): if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info) self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo): if isinstance(trace_info, ModerationTraceInfo):
print("Moderation trace: ", trace_info) self.moderation_trace(trace_info)
pass
# self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo): if isinstance(trace_info, SuggestedQuestionTraceInfo):
print("Suggested question trace: ", trace_info) self.suggested_question_trace(trace_info)
pass
# self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo): if isinstance(trace_info, DatasetRetrievalTraceInfo):
print("Dataset retrieval trace: ", trace_info) self.dataset_retrieval_trace(trace_info)
pass
# self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo): if isinstance(trace_info, ToolTraceInfo):
print("Tool trace: ", trace_info) self.tool_trace(trace_info)
pass
# self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo): if isinstance(trace_info, GenerateNameTraceInfo):
print("Generate name trace: ", trace_info) self.generate_name_trace(trace_info)
pass
# self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
if trace_info.message_id:
message_attributes = trace_info.metadata
message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
message_attributes["message_id"] = trace_info.message_id
message_attributes["workflow_run_id"] = trace_info.workflow_run_id
message_attributes["trace_id"] = trace_id
message_attributes["start_time"] = trace_info.start_time
message_attributes["end_time"] = trace_info.end_time
message_attributes["tags"] = ["message", "workflow"]
message_run = WeaveTraceModel(
id=trace_info.message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
attributes=message_attributes,
exception=trace_info.error,
file_list=[],
)
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(message_run)
workflow_attributes = trace_info.metadata
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
workflow_attributes["trace_id"] = trace_id
workflow_attributes["start_time"] = trace_info.start
workflow_attributes["end_time"] = trace_info.end_time
workflow_attributes["tags"] = ["workflow"]
workflow_run = WeaveTraceModel(
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
op=str(TraceTaskName.WORKFLOW_TRACE.value),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
attributes=workflow_attributes,
exception=trace_info.error,
)
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
# through workflow_run_id get all_nodes_execution
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
node_total_tokens = execution_metadata.get("total_tokens", 0)
attributes = execution_metadata.copy()
attributes.update(
{
"workflow_run_id": trace_info.workflow_run_id,
"node_execution_id": node_execution_id,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_name,
"node_type": node_type,
"status": status,
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat":
attributes.update(
{
"ls_provider": process_data.get("model_provider", ""),
"ls_model_name": process_data.get("model_name", ""),
}
)
attributes["tags"] = ["node_execution"]
attributes["start_time"] = created_at
attributes["end_time"] = finished_at
attributes["elapsed_time"] = elapsed_time
attributes["workflow_run_id"] = trace_info.workflow_run_id
attributes["trace_id"] = trace_id
node_run = WeaveTraceModel(
total_tokens=node_total_tokens,
op=node_type,
inputs=inputs,
outputs=outputs,
file_list=trace_info.file_list,
attributes=attributes,
id=node_execution_id,
)
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(node_run)
self.finish_call(workflow_run)
def message_trace(self, trace_info: MessageTraceInfo): def message_trace(self, trace_info: MessageTraceInfo):
# get message file data # get message file data
@ -88,14 +218,14 @@ class WeaveDataTrace(BaseTraceInstance):
message_file_data: Optional[MessageFile] = trace_info.message_file_data message_file_data: Optional[MessageFile] = trace_info.message_file_data
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url) file_list.append(file_url)
metadata = trace_info.metadata attributes = trace_info.metadata
message_data = trace_info.message_data message_data = trace_info.message_data
if message_data is None: if message_data is None:
return return
message_id = message_data.id message_id = message_data.id
user_id = message_data.from_account_id user_id = message_data.from_account_id
metadata["user_id"] = user_id attributes["user_id"] = user_id
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
@ -103,12 +233,12 @@ class WeaveDataTrace(BaseTraceInstance):
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id
metadata["end_user_id"] = end_user_id attributes["end_user_id"] = end_user_id
metadata["message_id"] = message_id attributes["message_id"] = message_id
metadata["start_time"]=trace_info.start_time attributes["start_time"] = trace_info.start_time
metadata["end_time"]=trace_info.end_time attributes["end_time"] = trace_info.end_time
metadata["tags"] = ["message", str(trace_info.conversation_mode)] attributes["tags"] = ["message", str(trace_info.conversation_mode)]
message_run = WeaveTraceModel( message_run = WeaveTraceModel(
id=message_id, id=message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value), op=str(TraceTaskName.MESSAGE_TRACE.value),
@ -119,9 +249,9 @@ class WeaveDataTrace(BaseTraceInstance):
outputs=trace_info.outputs, outputs=trace_info.outputs,
exception=trace_info.error, exception=trace_info.error,
file_list=file_list, file_list=file_list,
attributes=metadata attributes=attributes,
) )
self.add_run(message_run) self.start_call(message_run)
# create llm run parented to message run # create llm run parented to message run
llm_run = WeaveTraceModel( llm_run = WeaveTraceModel(
@ -132,20 +262,24 @@ class WeaveDataTrace(BaseTraceInstance):
op="llm", op="llm",
inputs=trace_info.inputs, inputs=trace_info.inputs,
outputs=trace_info.outputs, outputs=trace_info.outputs,
attributes=metadata, attributes=attributes,
)
self.start_call(
llm_run,
parent_run_id=message_id,
) )
self.add_run(llm_run, parent_run_id=message_id,) self.finish_call(llm_run)
self.update_run(llm_run) self.finish_call(message_run)
self.update_run(message_run)
def moderation_trace(self, trace_info: ModerationTraceInfo): def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None: if trace_info.message_data is None:
return return
metadata = trace_info.metadata attributes = trace_info.metadata
metadata["tags"] = ["moderation"] attributes["tags"] = ["moderation"]
metadata["start_time"] = trace_info.start_time or trace_info.message_data.created_at, attributes["message_id"] = trace_info.message_id
metadata["end_time"] = trace_info.end_time or trace_info.message_data.updated_at, attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
moderation_run = WeaveTraceModel( moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -157,9 +291,91 @@ class WeaveDataTrace(BaseTraceInstance):
"preset_response": trace_info.preset_response, "preset_response": trace_info.preset_response,
"inputs": trace_info.inputs, "inputs": trace_info.inputs,
}, },
attributes=metadata, attributes=attributes,
exception=trace_info.error,
)
self.start_call(moderation_run, parent_run_id=trace_info.message_id)
self.finish_call(moderation_run)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["suggested_question"]
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
attributes=attributes,
exception=trace_info.error,
)
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
self.finish_call(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["dataset_retrieval"]
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
attributes=attributes,
exception=trace_info.error,
)
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
self.finish_call(dataset_retrieval_run)
def tool_trace(self, trace_info: ToolTraceInfo):
attributes = trace_info.metadata
attributes["tags"] = ["tool", trace_info.tool_name]
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
tool_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=trace_info.tool_name,
inputs=trace_info.tool_inputs,
outputs=trace_info.tool_outputs,
file_list=[cast(str, trace_info.file_url)],
attributes=attributes,
exception=trace_info.error,
) )
self.add_run(moderation_run, parent_run_id=trace_info.message_id)
self.start_call(tool_run, parent_run_id=trace_info.message_id)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
attributes = trace_info.metadata
attributes["tags"] = ["generate_name"]
attributes["start_time"] = trace_info.start_time
attributes["end_time"] = trace_info.end_time
name_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,
exception=trace_info.error,
file_list=[],
)
self.start_call(name_run)
self.finish_call(name_run)
def api_check(self): def api_check(self):
try: try:
@ -173,13 +389,13 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug(f"Weave API check failed: {str(e)}") logger.debug(f"Weave API check failed: {str(e)}")
raise ValueError(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}")
def add_run(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
self.calls[run_data.id] = call self.calls[run_data.id] = call
if parent_run_id: if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id self.calls[run_data.id].parent_id = parent_run_id
def update_run(self, run_data: WeaveTraceModel): def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id) call = self.calls.get(run_data.id)
if call: if call:
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)

Loading…
Cancel
Save