feat: add host settings for dedicated cloud

pull/20765/head
Bharat Ramanathan 12 months ago
parent 85859b6723
commit dd28369cdf

@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig):
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"
host: str | None = None
@field_validator("endpoint")
@classmethod
@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig):
return v
@field_validator("host")
@classmethod
def validate_host(cls, v, info: ValidationInfo):
    """Check that a supplied ``host`` value carries an HTTP(S) scheme.

    ``None`` and the empty string are passed through untouched; any other
    value must begin with ``https://`` or ``http://``.

    :param v: the candidate host value.
    :param info: validation context; not consulted by this check.
    :return: ``v`` unchanged.
    :raises ValueError: if a non-empty value lacks an accepted scheme prefix.
    """
    # Nothing to validate when no host was provided.
    if v is None or v == "":
        return v

    has_scheme = v.startswith(("https://", "http://"))
    if not has_scheme:
        raise ValueError("host must start with https:// or http://")
    return v
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

@ -81,7 +81,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
@ -120,10 +120,12 @@ class OpsTraceManager:
if key in tracing_config:
if "*" in tracing_config[key]:
# If the key contains '*', retain the original value from the current config
new_config[key] = current_trace_config.get(key, tracing_config[key])
new_config[key] = current_trace_config.get(
key, tracing_config[key])
else:
# Otherwise, encrypt the key
new_config[key] = encrypt_token(tenant_id, tracing_config[key])
new_config[key] = encrypt_token(
tenant_id, tracing_config[key])
for key in other_keys:
new_config[key] = tracing_config.get(key, "")
@ -223,7 +225,8 @@ class OpsTraceManager:
if app_id is None:
return None
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(
App).filter(App.id == app_id).first()
if app is None:
return None
@ -243,7 +246,8 @@ class OpsTraceManager:
return None
# decrypt_token
decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
decrypt_trace_config = cls.get_decrypted_tracing_config(
app_id, tracing_provider)
if not decrypt_trace_config:
return None
@ -252,10 +256,12 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["config_class"],
)
decrypt_trace_config_key = str(decrypt_trace_config)
tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
tracing_instance = cls.ops_trace_instances_cache.get(
decrypt_trace_config_key)
if tracing_instance is None:
# create a new tracing_instance and update the cache if it is absent
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
tracing_instance = trace_instance(
config_class(**decrypt_trace_config))
cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
logging.info(f"new tracing_instance for app_id: {app_id}")
return tracing_instance
@ -263,11 +269,13 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).filter(Message.id == message_id).first()
message_data = db.session.query(Message).filter(
Message.id == message_id).first()
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
conversation_data = db.session.query(Conversation).filter(
Conversation.id == conversation_id).first()
if not conversation_data:
return None
@ -296,12 +304,15 @@ class OpsTraceManager:
try:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
raise ValueError(
f"Invalid tracing provider: {tracing_provider}")
else:
if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
raise ValueError(
f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app_config: Optional[App] = db.session.query(
App).filter(App.id == app_id).first()
if not app_config:
raise ValueError("App not found")
app_config.tracing = json.dumps(
@ -319,7 +330,8 @@ class OpsTraceManager:
:param app_id: app id
:return:
"""
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(
App).filter(App.id == app_id).first()
if not app:
raise ValueError("App not found")
if not app.tracing:
@ -439,7 +451,8 @@ class TraceTask:
return {}
with Session(db.engine) as session:
workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run_stmt = select(WorkflowRun).where(
WorkflowRun.id == workflow_run_id)
workflow_run = session.scalars(workflow_run_stmt).first()
if not workflow_run:
raise ValueError("Workflow run not found")
@ -457,7 +470,8 @@ class TraceTask:
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
query = workflow_run_inputs.get(
"query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
@ -519,7 +533,8 @@ class TraceTask:
message_data = get_message_data(message_id)
if not message_data:
return {}
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
conversation_mode_stmt = select(Conversation.mode).where(
Conversation.id == message_data.conversation_id)
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode or len(conversation_mode) == 0:
return {}
@ -528,7 +543,8 @@ class TraceTask:
inputs = message_data.message
# get message file data
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
message_file_data = db.session.query(
MessageFile).filter_by(message_id=message_id).first()
file_list = []
if message_file_data and message_file_data.url is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
@ -561,7 +577,8 @@ class TraceTask:
outputs=message_data.answer,
file_list=file_list,
start_time=created_at,
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
end_time=created_at +
timedelta(seconds=message_data.provider_response_latency),
metadata=metadata,
message_file_data=message_file_data,
conversation_mode=conversation_mode,
@ -588,9 +605,11 @@ class TraceTask:
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
db.session.query(WorkflowAppLog).filter_by(
workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
workflow_app_log_id = str(
workflow_app_log_data.id) if workflow_app_log_data else None
moderation_trace_info = ModerationTraceInfo(
message_id=workflow_app_log_id or message_id,
@ -628,9 +647,11 @@ class TraceTask:
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
db.session.query(WorkflowAppLog).filter_by(
workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
workflow_app_log_id = str(
workflow_app_log_data.id) if workflow_app_log_data else None
suggested_question_trace_info = SuggestedQuestionTraceInfo(
message_id=workflow_app_log_id or message_id,
@ -676,7 +697,8 @@ class TraceTask:
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
message_id=message_id,
inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents] if documents else [],
documents=[doc.model_dump()
for doc in documents] if documents else [],
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
@ -720,7 +742,8 @@ class TraceTask:
}
file_url = ""
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
message_file_data = db.session.query(
MessageFile).filter_by(message_id=message_id).first()
if message_file_data:
message_file_id = message_file_data.id if message_file_data else None
type = message_file_data.type
@ -788,7 +811,8 @@ class TraceTask:
trace_manager_timer: Optional[threading.Timer] = None
trace_manager_queue: queue.Queue = queue.Queue()
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
trace_manager_batch_size = int(
os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
class TraceQueueManager:
@ -809,7 +833,8 @@ class TraceQueueManager:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception as e:
logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
logging.exception(
f"Error adding trace task, trace_type {trace_task.trace_type}")
finally:
self.start_timer()
@ -833,7 +858,8 @@ class TraceQueueManager:
def start_timer(self):
global trace_manager_timer
if trace_manager_timer is None or not trace_manager_timer.is_alive():
trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
trace_manager_timer = threading.Timer(
trace_manager_interval, self.run)
trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
trace_manager_timer.daemon = False
trace_manager_timer.start()
@ -851,7 +877,8 @@ class TraceQueueManager:
trace_info=trace_info.model_dump() if trace_info else None,
)
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
storage.save(
file_path, task_data.model_dump_json().encode("utf-8"))
file_info = {
"file_id": file_id,
"app_id": task.app_id,

@ -40,16 +40,26 @@ class WeaveDataTrace(BaseTraceInstance):
self.weave_api_key = weave_config.api_key
self.project_name = weave_config.project
self.entity = weave_config.entity
# Login with API key first
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
self.host = weave_config.host
# Login with API key first, including host if provided
login_kwargs = {
"key": self.weave_api_key,
"verify": True,
"relogin": True,
}
if self.host:
login_kwargs["host"] = self.host
login_status = wandb.login(**login_kwargs)
if not login_status:
logger.error("Failed to login to Weights & Biases with the provided API key")
logger.error(
"Failed to login to Weights & Biases with the provided API key")
raise ValueError("Weave login failed")
# Then initialize weave client
self.weave_client = weave.init(
project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
project_name=(
f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls: dict[str, Any] = {}
@ -107,7 +117,8 @@ class WeaveDataTrace(BaseTraceInstance):
exception=trace_info.error,
file_list=[],
)
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
self.start_call(
message_run, parent_run_id=trace_info.workflow_run_id)
self.finish_call(message_run)
workflow_attributes = trace_info.metadata
@ -154,12 +165,14 @@ class WeaveDataTrace(BaseTraceInstance):
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
app_id = trace_info.metadata.get(
"app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
inputs = node_execution.process_data.get(
"prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
@ -168,7 +181,8 @@ class WeaveDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
node_total_tokens = execution_metadata.get(
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update(
{
@ -229,7 +243,8 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).filter(
EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
@ -307,8 +322,10 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["suggested_question"]
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
attributes["start_time"] = (
trace_info.start_time or message_data.created_at,)
attributes["end_time"] = (
trace_info.end_time or message_data.updated_at,)
suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()),
@ -320,7 +337,8 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=[],
)
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
self.start_call(suggested_question_run,
parent_run_id=trace_info.message_id)
self.finish_call(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
@ -329,8 +347,10 @@ class WeaveDataTrace(BaseTraceInstance):
attributes = trace_info.metadata
attributes["message_id"] = trace_info.message_id
attributes["tags"] = ["dataset_retrieval"]
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
attributes["start_time"] = (
trace_info.start_time or trace_info.message_data.created_at,)
attributes["end_time"] = (
trace_info.end_time or trace_info.message_data.updated_at,)
dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()),
@ -342,7 +362,8 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=[],
)
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
self.start_call(dataset_retrieval_run,
parent_run_id=trace_info.message_id)
self.finish_call(dataset_retrieval_run)
def tool_trace(self, trace_info: ToolTraceInfo):
@ -356,11 +377,13 @@ class WeaveDataTrace(BaseTraceInstance):
op=trace_info.tool_name,
inputs=trace_info.tool_inputs,
outputs=trace_info.tool_outputs,
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
file_list=[cast(str, trace_info.file_url)
] if trace_info.file_url else [],
attributes=attributes,
exception=trace_info.error,
)
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
message_id = trace_info.message_id or getattr(
trace_info, "conversation_id", None)
message_id = message_id or None
self.start_call(tool_run, parent_run_id=message_id)
self.finish_call(tool_run)
@ -386,7 +409,14 @@ class WeaveDataTrace(BaseTraceInstance):
def api_check(self):
try:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
login_kwargs = {
"key": self.weave_api_key,
"verify": True,
"relogin": True,
}
if self.host:
login_kwargs["host"] = self.host
login_status = wandb.login(**login_kwargs)
if not login_status:
raise ValueError("Weave login failed")
else:
@ -397,7 +427,8 @@ class WeaveDataTrace(BaseTraceInstance):
raise ValueError(f"Weave API check failed: {str(e)}")
def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
call = self.weave_client.create_call(
op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
self.calls[run_data.id] = call
if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id
@ -405,6 +436,7 @@ class WeaveDataTrace(BaseTraceInstance):
def finish_call(self, run_data: WeaveTraceModel):
    """Finish the call stored under ``run_data.id``.

    Looks the call up in ``self.calls`` and hands it to
    ``self.weave_client.finish_call`` together with the run's outputs and
    exception.

    :param run_data: run whose ``id`` keys the stored call and whose
        ``outputs`` / ``exception`` are reported.
    :raises ValueError: if no (truthy) call is stored for ``run_data.id``.
    """
    call = self.calls.get(run_data.id)
    if not call:
        raise ValueError(f"Call with id {run_data.id} not found")

    # Report the result exactly once; the rendered source carried both the
    # single-line and the wrapped form of this statement back to back.
    self.weave_client.finish_call(
        call=call,
        output=run_data.outputs,
        exception=run_data.exception,
    )

@ -55,6 +55,7 @@ const weaveConfigTemplate = {
entity: '',
project: '',
endpoint: '',
host: '',
}
const ProviderConfigModal: FC<Props> = ({
@ -226,6 +227,13 @@ const ProviderConfigModal: FC<Props> = ({
onChange={handleConfigChange('endpoint')}
placeholder={'https://trace.wandb.ai/'}
/>
<Field
label='Host'
labelClassName='!text-sm'
value={(config as WeaveConfig).host}
onChange={handleConfigChange('host')}
placeholder={'https://api.wandb.ai'}
/>
</>
)}
{type === TracingProvider.langSmith && (

@ -29,4 +29,5 @@ export type WeaveConfig = {
entity: string
project: string
endpoint: string
host: string
}

Loading…
Cancel
Save