diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 9587aaa93f..e0dfe0c312 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -120,12 +120,10 @@ class OpsTraceManager: if key in tracing_config: if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config - new_config[key] = current_trace_config.get( - key, tracing_config[key]) + new_config[key] = current_trace_config.get(key, tracing_config[key]) else: # Otherwise, encrypt the key - new_config[key] = encrypt_token( - tenant_id, tracing_config[key]) + new_config[key] = encrypt_token(tenant_id, tracing_config[key]) for key in other_keys: new_config[key] = tracing_config.get(key, "") @@ -225,8 +223,7 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = db.session.query( - App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() if app is None: return None @@ -246,8 +243,7 @@ class OpsTraceManager: return None # decrypt_token - decrypt_trace_config = cls.get_decrypted_tracing_config( - app_id, tracing_provider) + decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider) if not decrypt_trace_config: return None @@ -256,12 +252,10 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], ) decrypt_trace_config_key = str(decrypt_trace_config) - tracing_instance = cls.ops_trace_instances_cache.get( - decrypt_trace_config_key) + tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key) if tracing_instance is None: # create new tracing_instance and update the cache if it absent - tracing_instance = trace_instance( - config_class(**decrypt_trace_config)) + tracing_instance = trace_instance(config_class(**decrypt_trace_config)) cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance logging.info(f"new tracing_instance for app_id: {app_id}") return tracing_instance @@ -269,13 +263,11 @@ class OpsTraceManager: @classmethod def get_app_config_through_message_id(cls, message_id: str): app_model_config = None - message_data = db.session.query(Message).filter( - Message.id == message_id).first() + message_data = db.session.query(Message).filter(Message.id == message_id).first() if not message_data: return None conversation_id = message_data.conversation_id - conversation_data = db.session.query(Conversation).filter( - Conversation.id == conversation_id).first() + conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation_data: return None @@ -304,15 +296,12 @@ class OpsTraceManager: try: provider_config_map[tracing_provider] except KeyError: - raise ValueError( - f"Invalid tracing provider: {tracing_provider}") + raise ValueError(f"Invalid tracing provider: {tracing_provider}") else: if tracing_provider is not None: - raise ValueError( - f"Invalid tracing provider: {tracing_provider}") + raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query( - App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -330,8 +319,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query( - App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: @@ -451,8 +439,7 @@ class TraceTask: return {} with Session(db.engine) as session: - workflow_run_stmt = select(WorkflowRun).where( - WorkflowRun.id == workflow_run_id) + workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) workflow_run = session.scalars(workflow_run_stmt).first() if not workflow_run: raise ValueError("Workflow run not found") @@ -470,8 +457,7 @@ class TraceTask: total_tokens = workflow_run.total_tokens file_list = workflow_run_inputs.get("sys.file") or [] - query = workflow_run_inputs.get( - "query") or workflow_run_inputs.get("sys.query") or "" + query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" # get workflow_app_log_id workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( @@ -533,8 +519,7 @@ class TraceTask: message_data = get_message_data(message_id) if not message_data: return {} - conversation_mode_stmt = select(Conversation.mode).where( - Conversation.id == message_data.conversation_id) + conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) conversation_mode = db.session.scalars(conversation_mode_stmt).all() if not conversation_mode or len(conversation_mode) == 0: return {} @@ -543,8 +528,7 @@ class TraceTask: inputs = message_data.message # get message file data - message_file_data = db.session.query( - MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() file_list = [] if message_file_data and message_file_data.url is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -577,8 +561,7 @@ class TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + - timedelta(seconds=message_data.provider_response_latency), + end_time=created_at + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -605,11 +588,9 @@ class TraceTask: workflow_app_log_id = None if message_data.workflow_run_id: workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id).first() + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() ) - workflow_app_log_id = str( - workflow_app_log_data.id) if workflow_app_log_data else None + workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( message_id=workflow_app_log_id or message_id, @@ -647,11 +628,9 @@ class TraceTask: workflow_app_log_id = None if message_data.workflow_run_id: workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id).first() + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() ) - workflow_app_log_id = str( - workflow_app_log_data.id) if workflow_app_log_data else None + workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( message_id=workflow_app_log_id or message_id, @@ -697,8 +676,7 @@ class TraceTask: dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() - for doc in documents] if documents else [], + documents=[doc.model_dump() for doc in documents] if documents else [], start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -742,8 +720,7 @@ class TraceTask: } file_url = "" - message_file_data = db.session.query( - MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() if message_file_data: message_file_id = message_file_data.id if message_file_data else None type = message_file_data.type @@ -811,8 +788,7 @@ class TraceTask: trace_manager_timer: Optional[threading.Timer] = None trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) -trace_manager_batch_size = int( - os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) +trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) class TraceQueueManager: @@ -833,8 +809,7 @@ class TraceQueueManager: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception as e: - logging.exception( - f"Error adding trace task, trace_type {trace_task.trace_type}") + logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}") finally: self.start_timer() @@ -858,8 +833,7 @@ class TraceQueueManager: def start_timer(self): global trace_manager_timer if trace_manager_timer is None or not trace_manager_timer.is_alive(): - trace_manager_timer = threading.Timer( - trace_manager_interval, self.run) + trace_manager_timer = threading.Timer(trace_manager_interval, self.run) trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.daemon = False trace_manager_timer.start() @@ -877,8 +851,7 @@ class TraceQueueManager: trace_info=trace_info.model_dump() if trace_info else None, ) file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" - storage.save( - file_path, task_data.model_dump_json().encode("utf-8")) + storage.save(file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, "app_id": task.app_id, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index aede340bae..5110904014 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -4,10 +4,10 @@ import uuid from datetime import datetime, timedelta from typing import Any, Optional, cast -import wandb import weave from sqlalchemy.orm import sessionmaker +import wandb from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( @@ -43,23 +43,18 @@ class WeaveDataTrace(BaseTraceInstance): self.host = weave_config.host # Login with API key first, including host if provided - login_kwargs = { - "key": self.weave_api_key, - "verify": True, - "relogin": True, - } if self.host: - login_kwargs["host"] = self.host - login_status = wandb.login(**login_kwargs) + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host) + else: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + if not login_status: - logger.error( - "Failed to login to Weights & Biases with the provided API key") + logger.error("Failed to login to Weights & Biases with the provided API key") raise ValueError("Weave login failed") # Then initialize weave client self.weave_client = weave.init( - project_name=( - f"{self.entity}/{self.project_name}" if self.entity else self.project_name) + project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name) ) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.calls: dict[str, Any] = {} @@ -117,8 +112,7 @@ class WeaveDataTrace(BaseTraceInstance): exception=trace_info.error, file_list=[], ) - self.start_call( - message_run, parent_run_id=trace_info.workflow_run_id) + self.start_call(message_run, parent_run_id=trace_info.workflow_run_id) self.finish_call(message_run) workflow_attributes = trace_info.metadata @@ -165,14 +159,12 @@ class WeaveDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = trace_info.tenant_id # Use from trace_info instead - app_id = trace_info.metadata.get( - "app_id") # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status if node_type == NodeType.LLM: - inputs = node_execution.process_data.get( - "prompts", {}) if node_execution.process_data else {} + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs if node_execution.inputs else {} outputs = node_execution.outputs if node_execution.outputs else {} @@ -181,8 +173,7 @@ class WeaveDataTrace(BaseTraceInstance): finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = node_execution.metadata if node_execution.metadata else {} - node_total_tokens = execution_metadata.get( - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 + node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( { @@ -243,8 +234,7 @@ class WeaveDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter( - EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id @@ -322,10 +312,8 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["message_id"] = trace_info.message_id attributes["tags"] = ["suggested_question"] - attributes["start_time"] = ( - trace_info.start_time or message_data.created_at,) - attributes["end_time"] = ( - trace_info.end_time or message_data.updated_at,) + attributes["start_time"] = (trace_info.start_time or message_data.created_at,) + attributes["end_time"] = (trace_info.end_time or message_data.updated_at,) suggested_question_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -337,8 +325,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(suggested_question_run, - parent_run_id=trace_info.message_id) + self.start_call(suggested_question_run, parent_run_id=trace_info.message_id) self.finish_call(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): @@ -347,10 +334,8 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["message_id"] = trace_info.message_id attributes["tags"] = ["dataset_retrieval"] - attributes["start_time"] = ( - trace_info.start_time or trace_info.message_data.created_at,) - attributes["end_time"] = ( - trace_info.end_time or trace_info.message_data.updated_at,) + attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) + attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -362,8 +347,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(dataset_retrieval_run, - parent_run_id=trace_info.message_id) + self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id) self.finish_call(dataset_retrieval_run) def tool_trace(self, trace_info: ToolTraceInfo): @@ -377,13 +361,11 @@ class WeaveDataTrace(BaseTraceInstance): op=trace_info.tool_name, inputs=trace_info.tool_inputs, outputs=trace_info.tool_outputs, - file_list=[cast(str, trace_info.file_url) - ] if trace_info.file_url else [], + file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [], attributes=attributes, exception=trace_info.error, ) - message_id = trace_info.message_id or getattr( - trace_info, "conversation_id", None) + message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) message_id = message_id or None self.start_call(tool_run, parent_run_id=message_id) self.finish_call(tool_run) @@ -409,14 +391,11 @@ class WeaveDataTrace(BaseTraceInstance): def api_check(self): try: - login_kwargs = { - "key": self.weave_api_key, - "verify": True, - "relogin": True, - } if self.host: - login_kwargs["host"] = self.host - login_status = wandb.login(**login_kwargs) + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host) + else: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + if not login_status: raise ValueError("Weave login failed") else: @@ -427,8 +406,7 @@ class WeaveDataTrace(BaseTraceInstance): raise ValueError(f"Weave API check failed: {str(e)}") def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): - call = self.weave_client.create_call( - op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) + call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) self.calls[run_data.id] = call if parent_run_id: self.calls[run_data.id].parent_id = parent_run_id @@ -436,7 +414,6 @@ class WeaveDataTrace(BaseTraceInstance): def finish_call(self, run_data: WeaveTraceModel): call = self.calls.get(run_data.id) if call: - self.weave_client.finish_call( - call=call, output=run_data.outputs, exception=run_data.exception) + self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) else: raise ValueError(f"Call with id {run_data.id} not found")