diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index f2d1bd305a..c988bf48d1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig): entity: str | None = None project: str endpoint: str = "https://trace.wandb.ai" + host: str | None = None @field_validator("endpoint") @classmethod @@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig): return v + @field_validator("host") + @classmethod + def validate_host(cls, v, info: ValidationInfo): + if v is not None and v != "": + if not v.startswith(("https://", "http://")): + raise ValueError("host must start with https:// or http://") + return v + OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index dc4cfc48db..9587aaa93f 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -81,7 +81,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): return { "config_class": WeaveConfig, "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint"], + "other_keys": ["project", "entity", "endpoint", "host"], "trace_instance": WeaveDataTrace, } @@ -120,10 +120,12 @@ class OpsTraceManager: if key in tracing_config: if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config - new_config[key] = current_trace_config.get(key, tracing_config[key]) + new_config[key] = current_trace_config.get( + key, tracing_config[key]) else: # Otherwise, encrypt the key - new_config[key] = encrypt_token(tenant_id, tracing_config[key]) + new_config[key] = encrypt_token( + tenant_id, tracing_config[key]) for key in other_keys: new_config[key] = tracing_config.get(key, "") @@ -223,7 +225,8 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query( + App).filter(App.id == app_id).first() if app is None: return None @@ -243,7 +246,8 @@ class OpsTraceManager: return None # decrypt_token - decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider) + decrypt_trace_config = cls.get_decrypted_tracing_config( + app_id, tracing_provider) if not decrypt_trace_config: return None @@ -252,10 +256,12 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], ) decrypt_trace_config_key = str(decrypt_trace_config) - tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key) + tracing_instance = cls.ops_trace_instances_cache.get( + decrypt_trace_config_key) if tracing_instance is None: # create new tracing_instance and update the cache if it absent - tracing_instance = trace_instance(config_class(**decrypt_trace_config)) + tracing_instance = trace_instance( + config_class(**decrypt_trace_config)) cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance logging.info(f"new tracing_instance for app_id: {app_id}") return tracing_instance @@ -263,11 +269,13 @@ class OpsTraceManager: @classmethod def get_app_config_through_message_id(cls, message_id: str): app_model_config = None - message_data = db.session.query(Message).filter(Message.id == message_id).first() + message_data = db.session.query(Message).filter( + Message.id == message_id).first() if not message_data: return None conversation_id = message_data.conversation_id - conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation_data = db.session.query(Conversation).filter( + Conversation.id == conversation_id).first() if not conversation_data: return None @@ -296,12 +304,15 @@ class OpsTraceManager: try: provider_config_map[tracing_provider] except KeyError: - raise ValueError(f"Invalid tracing provider: {tracing_provider}") + raise ValueError( + f"Invalid tracing provider: {tracing_provider}") else: if tracing_provider is not None: - raise ValueError(f"Invalid tracing provider: {tracing_provider}") + raise ValueError( + f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query( + App).filter(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -319,7 +330,8 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query( + App).filter(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: @@ -439,7 +451,8 @@ class TraceTask: return {} with Session(db.engine) as session: - workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run_stmt = select(WorkflowRun).where( + WorkflowRun.id == workflow_run_id) workflow_run = session.scalars(workflow_run_stmt).first() if not workflow_run: raise ValueError("Workflow run not found") @@ -457,7 +470,8 @@ class TraceTask: total_tokens = workflow_run.total_tokens file_list = workflow_run_inputs.get("sys.file") or [] - query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" + query = workflow_run_inputs.get( + "query") or workflow_run_inputs.get("sys.query") or "" # get workflow_app_log_id workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( @@ -519,7 +533,8 @@ class TraceTask: message_data = get_message_data(message_id) if not message_data: return {} - conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) + conversation_mode_stmt = select(Conversation.mode).where( + Conversation.id == message_data.conversation_id) conversation_mode = db.session.scalars(conversation_mode_stmt).all() if not conversation_mode or len(conversation_mode) == 0: return {} @@ -528,7 +543,8 @@ class TraceTask: inputs = message_data.message # get message file data - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.query( + MessageFile).filter_by(message_id=message_id).first() file_list = [] if message_file_data and message_file_data.url is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -561,7 +577,8 @@ class TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + timedelta(seconds=message_data.provider_response_latency), + end_time=created_at + + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -588,9 +605,11 @@ class TraceTask: workflow_app_log_id = None if message_data.workflow_run_id: workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + db.session.query(WorkflowAppLog).filter_by( + workflow_run_id=message_data.workflow_run_id).first() ) - workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None + workflow_app_log_id = str( + workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( message_id=workflow_app_log_id or message_id, @@ -628,9 +647,11 @@ class TraceTask: workflow_app_log_id = None if message_data.workflow_run_id: workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + db.session.query(WorkflowAppLog).filter_by( + workflow_run_id=message_data.workflow_run_id).first() ) - workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None + workflow_app_log_id = str( + workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( message_id=workflow_app_log_id or message_id, @@ -676,7 +697,8 @@ class TraceTask: dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=[doc.model_dump() + for doc in documents] if documents else [], start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -720,7 +742,8 @@ class TraceTask: } file_url = "" - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.query( + MessageFile).filter_by(message_id=message_id).first() if message_file_data: message_file_id = message_file_data.id if message_file_data else None type = message_file_data.type @@ -788,7 +811,8 @@ class TraceTask: trace_manager_timer: Optional[threading.Timer] = None trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) -trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) +trace_manager_batch_size = int( + os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) class TraceQueueManager: @@ -809,7 +833,8 @@ class TraceQueueManager: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception as e: - logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}") + logging.exception( + f"Error adding trace task, trace_type {trace_task.trace_type}") finally: self.start_timer() @@ -833,7 +858,8 @@ class TraceQueueManager: def start_timer(self): global trace_manager_timer if trace_manager_timer is None or not trace_manager_timer.is_alive(): - trace_manager_timer = threading.Timer(trace_manager_interval, self.run) + trace_manager_timer = threading.Timer( + trace_manager_interval, self.run) trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.daemon = False trace_manager_timer.start() @@ -851,7 +877,8 @@ class TraceQueueManager: trace_info=trace_info.model_dump() if trace_info else None, ) file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" - storage.save(file_path, task_data.model_dump_json().encode("utf-8")) + storage.save( + file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, "app_id": task.app_id, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index cfc8a505bb..aede340bae 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -40,16 +40,26 @@ class WeaveDataTrace(BaseTraceInstance): self.weave_api_key = weave_config.api_key self.project_name = weave_config.project self.entity = weave_config.entity - - # Login with API key first - login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + self.host = weave_config.host + + # Login with API key first, including host if provided + login_kwargs = { + "key": self.weave_api_key, + "verify": True, + "relogin": True, + } + if self.host: + login_kwargs["host"] = self.host + login_status = wandb.login(**login_kwargs) if not login_status: - logger.error("Failed to login to Weights & Biases with the provided API key") + logger.error( + "Failed to login to Weights & Biases with the provided API key") raise ValueError("Weave login failed") # Then initialize weave client self.weave_client = weave.init( - project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name) + project_name=( + f"{self.entity}/{self.project_name}" if self.entity else self.project_name) ) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.calls: dict[str, Any] = {} @@ -107,7 +117,8 @@ class WeaveDataTrace(BaseTraceInstance): exception=trace_info.error, file_list=[], ) - self.start_call(message_run, parent_run_id=trace_info.workflow_run_id) + self.start_call( + message_run, parent_run_id=trace_info.workflow_run_id) self.finish_call(message_run) workflow_attributes = trace_info.metadata @@ -154,12 +165,14 @@ class WeaveDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = trace_info.tenant_id # Use from trace_info instead - app_id = trace_info.metadata.get("app_id") # Use from trace_info instead + app_id = trace_info.metadata.get( + "app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status if node_type == NodeType.LLM: - inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} + inputs = node_execution.process_data.get( + "prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs if node_execution.inputs else {} outputs = node_execution.outputs if node_execution.outputs else {} @@ -168,7 +181,8 @@ class WeaveDataTrace(BaseTraceInstance): finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = node_execution.metadata if node_execution.metadata else {} - node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 + node_total_tokens = execution_metadata.get( + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( { @@ -229,7 +243,8 @@ class WeaveDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).filter( + EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id @@ -307,8 +322,10 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["message_id"] = trace_info.message_id attributes["tags"] = ["suggested_question"] - attributes["start_time"] = (trace_info.start_time or message_data.created_at,) - attributes["end_time"] = (trace_info.end_time or message_data.updated_at,) + attributes["start_time"] = ( + trace_info.start_time or message_data.created_at,) + attributes["end_time"] = ( + trace_info.end_time or message_data.updated_at,) suggested_question_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -320,7 +337,8 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(suggested_question_run, parent_run_id=trace_info.message_id) + self.start_call(suggested_question_run, + parent_run_id=trace_info.message_id) self.finish_call(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): @@ -329,8 +347,10 @@ class WeaveDataTrace(BaseTraceInstance): attributes = trace_info.metadata attributes["message_id"] = trace_info.message_id attributes["tags"] = ["dataset_retrieval"] - attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) - attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) + attributes["start_time"] = ( + trace_info.start_time or trace_info.message_data.created_at,) + attributes["end_time"] = ( + trace_info.end_time or trace_info.message_data.updated_at,) dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), @@ -342,7 +362,8 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id) + self.start_call(dataset_retrieval_run, + parent_run_id=trace_info.message_id) self.finish_call(dataset_retrieval_run) def tool_trace(self, trace_info: ToolTraceInfo): @@ -356,11 +377,13 @@ class WeaveDataTrace(BaseTraceInstance): op=trace_info.tool_name, inputs=trace_info.tool_inputs, outputs=trace_info.tool_outputs, - file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [], + file_list=[cast(str, trace_info.file_url) + ] if trace_info.file_url else [], attributes=attributes, exception=trace_info.error, ) - message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) + message_id = trace_info.message_id or getattr( + trace_info, "conversation_id", None) message_id = message_id or None self.start_call(tool_run, parent_run_id=message_id) self.finish_call(tool_run) @@ -386,7 +409,14 @@ class WeaveDataTrace(BaseTraceInstance): def api_check(self): try: - login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + login_kwargs = { + "key": self.weave_api_key, + "verify": True, + "relogin": True, + } + if self.host: + login_kwargs["host"] = self.host + login_status = wandb.login(**login_kwargs) if not login_status: raise ValueError("Weave login failed") else: @@ -397,7 +427,8 @@ class WeaveDataTrace(BaseTraceInstance): raise ValueError(f"Weave API check failed: {str(e)}") def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): - call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) + call = self.weave_client.create_call( + op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) self.calls[run_data.id] = call if parent_run_id: self.calls[run_data.id].parent_id = parent_run_id @@ -405,6 +436,7 @@ class WeaveDataTrace(BaseTraceInstance): def finish_call(self, run_data: WeaveTraceModel): call = self.calls.get(run_data.id) if call: - self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) + self.weave_client.finish_call( + call=call, output=run_data.outputs, exception=run_data.exception) else: raise ValueError(f"Call with id {run_data.id} not found") diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index c0b52a9b10..b6c97add48 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -55,6 +55,7 @@ const weaveConfigTemplate = { entity: '', project: '', endpoint: '', + host: '', } const ProviderConfigModal: FC = ({ @@ -226,6 +227,13 @@ const ProviderConfigModal: FC = ({ onChange={handleConfigChange('endpoint')} placeholder={'https://trace.wandb.ai/'} /> + )} {type === TracingProvider.langSmith && ( diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts index 386c58974e..ed468caf65 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts @@ -29,4 +29,5 @@ export type WeaveConfig = { entity: string project: string endpoint: string + host: string }