chore: run dev/reformat and fix lint issues

pull/20765/head
Bharat Ramanathan 12 months ago
parent dd28369cdf
commit d1243b0a01

@ -120,12 +120,10 @@ class OpsTraceManager:
if key in tracing_config: if key in tracing_config:
if "*" in tracing_config[key]: if "*" in tracing_config[key]:
# If the key contains '*', retain the original value from the current config # If the key contains '*', retain the original value from the current config
new_config[key] = current_trace_config.get( new_config[key] = current_trace_config.get(key, tracing_config[key])
key, tracing_config[key])
else: else:
# Otherwise, encrypt the key # Otherwise, encrypt the key
new_config[key] = encrypt_token( new_config[key] = encrypt_token(tenant_id, tracing_config[key])
tenant_id, tracing_config[key])
for key in other_keys: for key in other_keys:
new_config[key] = tracing_config.get(key, "") new_config[key] = tracing_config.get(key, "")
@ -225,8 +223,7 @@ class OpsTraceManager:
if app_id is None: if app_id is None:
return None return None
app: Optional[App] = db.session.query( app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
App).filter(App.id == app_id).first()
if app is None: if app is None:
return None return None
@ -246,8 +243,7 @@ class OpsTraceManager:
return None return None
# decrypt_token # decrypt_token
decrypt_trace_config = cls.get_decrypted_tracing_config( decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
app_id, tracing_provider)
if not decrypt_trace_config: if not decrypt_trace_config:
return None return None
@ -256,12 +252,10 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["config_class"],
) )
decrypt_trace_config_key = str(decrypt_trace_config) decrypt_trace_config_key = str(decrypt_trace_config)
tracing_instance = cls.ops_trace_instances_cache.get( tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
decrypt_trace_config_key)
if tracing_instance is None: if tracing_instance is None:
# create new tracing_instance and update the cache if it absent # create new tracing_instance and update the cache if it absent
tracing_instance = trace_instance( tracing_instance = trace_instance(config_class(**decrypt_trace_config))
config_class(**decrypt_trace_config))
cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
logging.info(f"new tracing_instance for app_id: {app_id}") logging.info(f"new tracing_instance for app_id: {app_id}")
return tracing_instance return tracing_instance
@ -269,13 +263,11 @@ class OpsTraceManager:
@classmethod @classmethod
def get_app_config_through_message_id(cls, message_id: str): def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None app_model_config = None
message_data = db.session.query(Message).filter( message_data = db.session.query(Message).filter(Message.id == message_id).first()
Message.id == message_id).first()
if not message_data: if not message_data:
return None return None
conversation_id = message_data.conversation_id conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).filter( conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
Conversation.id == conversation_id).first()
if not conversation_data: if not conversation_data:
return None return None
@ -304,15 +296,12 @@ class OpsTraceManager:
try: try:
provider_config_map[tracing_provider] provider_config_map[tracing_provider]
except KeyError: except KeyError:
raise ValueError( raise ValueError(f"Invalid tracing provider: {tracing_provider}")
f"Invalid tracing provider: {tracing_provider}")
else: else:
if tracing_provider is not None: if tracing_provider is not None:
raise ValueError( raise ValueError(f"Invalid tracing provider: {tracing_provider}")
f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query( app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
App).filter(App.id == app_id).first()
if not app_config: if not app_config:
raise ValueError("App not found") raise ValueError("App not found")
app_config.tracing = json.dumps( app_config.tracing = json.dumps(
@ -330,8 +319,7 @@ class OpsTraceManager:
:param app_id: app id :param app_id: app id
:return: :return:
""" """
app: Optional[App] = db.session.query( app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
App).filter(App.id == app_id).first()
if not app: if not app:
raise ValueError("App not found") raise ValueError("App not found")
if not app.tracing: if not app.tracing:
@ -451,8 +439,7 @@ class TraceTask:
return {} return {}
with Session(db.engine) as session: with Session(db.engine) as session:
workflow_run_stmt = select(WorkflowRun).where( workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
WorkflowRun.id == workflow_run_id)
workflow_run = session.scalars(workflow_run_stmt).first() workflow_run = session.scalars(workflow_run_stmt).first()
if not workflow_run: if not workflow_run:
raise ValueError("Workflow run not found") raise ValueError("Workflow run not found")
@ -470,8 +457,7 @@ class TraceTask:
total_tokens = workflow_run.total_tokens total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or [] file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get( query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
"query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id # get workflow_app_log_id
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
@ -533,8 +519,7 @@ class TraceTask:
message_data = get_message_data(message_id) message_data = get_message_data(message_id)
if not message_data: if not message_data:
return {} return {}
conversation_mode_stmt = select(Conversation.mode).where( conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
Conversation.id == message_data.conversation_id)
conversation_mode = db.session.scalars(conversation_mode_stmt).all() conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode or len(conversation_mode) == 0: if not conversation_mode or len(conversation_mode) == 0:
return {} return {}
@ -543,8 +528,7 @@ class TraceTask:
inputs = message_data.message inputs = message_data.message
# get message file data # get message file data
message_file_data = db.session.query( message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
MessageFile).filter_by(message_id=message_id).first()
file_list = [] file_list = []
if message_file_data and message_file_data.url is not None: if message_file_data and message_file_data.url is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
@ -577,8 +561,7 @@ class TraceTask:
outputs=message_data.answer, outputs=message_data.answer,
file_list=file_list, file_list=file_list,
start_time=created_at, start_time=created_at,
end_time=created_at + end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
timedelta(seconds=message_data.provider_response_latency),
metadata=metadata, metadata=metadata,
message_file_data=message_file_data, message_file_data=message_file_data,
conversation_mode=conversation_mode, conversation_mode=conversation_mode,
@ -605,11 +588,9 @@ class TraceTask:
workflow_app_log_id = None workflow_app_log_id = None
if message_data.workflow_run_id: if message_data.workflow_run_id:
workflow_app_log_data = ( workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by( db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
workflow_run_id=message_data.workflow_run_id).first()
) )
workflow_app_log_id = str( workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
workflow_app_log_data.id) if workflow_app_log_data else None
moderation_trace_info = ModerationTraceInfo( moderation_trace_info = ModerationTraceInfo(
message_id=workflow_app_log_id or message_id, message_id=workflow_app_log_id or message_id,
@ -647,11 +628,9 @@ class TraceTask:
workflow_app_log_id = None workflow_app_log_id = None
if message_data.workflow_run_id: if message_data.workflow_run_id:
workflow_app_log_data = ( workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by( db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
workflow_run_id=message_data.workflow_run_id).first()
) )
workflow_app_log_id = str( workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
workflow_app_log_data.id) if workflow_app_log_data else None
suggested_question_trace_info = SuggestedQuestionTraceInfo( suggested_question_trace_info = SuggestedQuestionTraceInfo(
message_id=workflow_app_log_id or message_id, message_id=workflow_app_log_id or message_id,
@ -697,8 +676,7 @@ class TraceTask:
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
message_id=message_id, message_id=message_id,
inputs=message_data.query or message_data.inputs, inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() documents=[doc.model_dump() for doc in documents] if documents else [],
for doc in documents] if documents else [],
start_time=timer.get("start"), start_time=timer.get("start"),
end_time=timer.get("end"), end_time=timer.get("end"),
metadata=metadata, metadata=metadata,
@ -742,8 +720,7 @@ class TraceTask:
} }
file_url = "" file_url = ""
message_file_data = db.session.query( message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
MessageFile).filter_by(message_id=message_id).first()
if message_file_data: if message_file_data:
message_file_id = message_file_data.id if message_file_data else None message_file_id = message_file_data.id if message_file_data else None
type = message_file_data.type type = message_file_data.type
@ -811,8 +788,7 @@ class TraceTask:
trace_manager_timer: Optional[threading.Timer] = None trace_manager_timer: Optional[threading.Timer] = None
trace_manager_queue: queue.Queue = queue.Queue() trace_manager_queue: queue.Queue = queue.Queue()
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
trace_manager_batch_size = int( trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
class TraceQueueManager: class TraceQueueManager:
@ -833,8 +809,7 @@ class TraceQueueManager:
trace_task.app_id = self.app_id trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task) trace_manager_queue.put(trace_task)
except Exception as e: except Exception as e:
logging.exception( logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
f"Error adding trace task, trace_type {trace_task.trace_type}")
finally: finally:
self.start_timer() self.start_timer()
@ -858,8 +833,7 @@ class TraceQueueManager:
def start_timer(self): def start_timer(self):
global trace_manager_timer global trace_manager_timer
if trace_manager_timer is None or not trace_manager_timer.is_alive(): if trace_manager_timer is None or not trace_manager_timer.is_alive():
trace_manager_timer = threading.Timer( trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
trace_manager_interval, self.run)
trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
trace_manager_timer.daemon = False trace_manager_timer.daemon = False
trace_manager_timer.start() trace_manager_timer.start()
@ -877,8 +851,7 @@ class TraceQueueManager:
trace_info=trace_info.model_dump() if trace_info else None, trace_info=trace_info.model_dump() if trace_info else None,
) )
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
storage.save( storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
file_path, task_data.model_dump_json().encode("utf-8"))
file_info = { file_info = {
"file_id": file_id, "file_id": file_id,
"app_id": task.app_id, "app_id": task.app_id,

@ -4,10 +4,10 @@ import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Optional, cast from typing import Any, Optional, cast
import wandb
import weave import weave
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
import wandb
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import ( from core.ops.entities.trace_entity import (
@ -43,23 +43,18 @@ class WeaveDataTrace(BaseTraceInstance):
self.host = weave_config.host self.host = weave_config.host
# Login with API key first, including host if provided # Login with API key first, including host if provided
login_kwargs = {
"key": self.weave_api_key,
"verify": True,
"relogin": True,
}
if self.host: if self.host:
login_kwargs["host"] = self.host login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
login_status = wandb.login(**login_kwargs) else:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status: if not login_status:
logger.error( logger.error("Failed to login to Weights & Biases with the provided API key")
"Failed to login to Weights & Biases with the provided API key")
raise ValueError("Weave login failed") raise ValueError("Weave login failed")
# Then initialize weave client # Then initialize weave client
self.weave_client = weave.init( self.weave_client = weave.init(
project_name=( project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
) )
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls: dict[str, Any] = {} self.calls: dict[str, Any] = {}
@ -117,8 +112,7 @@ class WeaveDataTrace(BaseTraceInstance):
exception=trace_info.error, exception=trace_info.error,
file_list=[], file_list=[],
) )
self.start_call( self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
message_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(message_run) self.finish_call(message_run)
workflow_attributes = trace_info.metadata workflow_attributes = trace_info.metadata
@ -165,14 +159,12 @@ class WeaveDataTrace(BaseTraceInstance):
for node_execution in workflow_node_executions: for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = trace_info.tenant_id # Use from trace_info instead tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get( app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
"app_id") # Use from trace_info instead
node_name = node_execution.title node_name = node_execution.title
node_type = node_execution.node_type node_type = node_execution.node_type
status = node_execution.status status = node_execution.status
if node_type == NodeType.LLM: if node_type == NodeType.LLM:
inputs = node_execution.process_data.get( inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
"prompts", {}) if node_execution.process_data else {}
else: else:
inputs = node_execution.inputs if node_execution.inputs else {} inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {} outputs = node_execution.outputs if node_execution.outputs else {}
@ -181,8 +173,7 @@ class WeaveDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {} execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get( node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()} attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update( attributes.update(
{ {
@ -243,8 +234,7 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter( db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id
@ -322,10 +312,8 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["suggested_question"] attributes["tags"] = ["suggested_question"]
attributes["start_time"] = ( attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
trace_info.start_time or message_data.created_at,) attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
attributes["end_time"] = (
trace_info.end_time or message_data.updated_at,)
suggested_question_run = WeaveTraceModel( suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -337,8 +325,7 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=[], file_list=[],
) )
self.start_call(suggested_question_run, self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
parent_run_id=trace_info.message_id)
self.finish_call(suggested_question_run) self.finish_call(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
@ -347,10 +334,8 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["dataset_retrieval"] attributes["tags"] = ["dataset_retrieval"]
attributes["start_time"] = ( attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
trace_info.start_time or trace_info.message_data.created_at,) attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
attributes["end_time"] = (
trace_info.end_time or trace_info.message_data.updated_at,)
dataset_retrieval_run = WeaveTraceModel( dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@ -362,8 +347,7 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=[], file_list=[],
) )
self.start_call(dataset_retrieval_run, self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
parent_run_id=trace_info.message_id)
self.finish_call(dataset_retrieval_run) self.finish_call(dataset_retrieval_run)
def tool_trace(self, trace_info: ToolTraceInfo): def tool_trace(self, trace_info: ToolTraceInfo):
@ -377,13 +361,11 @@ class WeaveDataTrace(BaseTraceInstance):
op=trace_info.tool_name, op=trace_info.tool_name,
inputs=trace_info.tool_inputs, inputs=trace_info.tool_inputs,
outputs=trace_info.tool_outputs, outputs=trace_info.tool_outputs,
file_list=[cast(str, trace_info.file_url) file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
] if trace_info.file_url else [],
attributes=attributes, attributes=attributes,
exception=trace_info.error, exception=trace_info.error,
) )
message_id = trace_info.message_id or getattr( message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
trace_info, "conversation_id", None)
message_id = message_id or None message_id = message_id or None
self.start_call(tool_run, parent_run_id=message_id) self.start_call(tool_run, parent_run_id=message_id)
self.finish_call(tool_run) self.finish_call(tool_run)
@ -409,14 +391,11 @@ class WeaveDataTrace(BaseTraceInstance):
def api_check(self): def api_check(self):
try: try:
login_kwargs = {
"key": self.weave_api_key,
"verify": True,
"relogin": True,
}
if self.host: if self.host:
login_kwargs["host"] = self.host login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
login_status = wandb.login(**login_kwargs) else:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status: if not login_status:
raise ValueError("Weave login failed") raise ValueError("Weave login failed")
else: else:
@ -427,8 +406,7 @@ class WeaveDataTrace(BaseTraceInstance):
raise ValueError(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}")
def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
call = self.weave_client.create_call( call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
self.calls[run_data.id] = call self.calls[run_data.id] = call
if parent_run_id: if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id self.calls[run_data.id].parent_id = parent_run_id
@ -436,7 +414,6 @@ class WeaveDataTrace(BaseTraceInstance):
def finish_call(self, run_data: WeaveTraceModel): def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id) call = self.calls.get(run_data.id)
if call: if call:
self.weave_client.finish_call( self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
call=call, output=run_data.outputs, exception=run_data.exception)
else: else:
raise ValueError(f"Call with id {run_data.id} not found") raise ValueError(f"Call with id {run_data.id} not found")

Loading…
Cancel
Save