Merge branch 'main' into e-300
commit
62347206c0
@ -0,0 +1,45 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field, PositiveInt
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class VastbaseVectorConfig(BaseSettings):
    """
    Configuration settings for Vector (Vastbase with vector extension)
    """

    VASTBASE_HOST: Optional[str] = Field(
        default=None,
        description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')",
    )

    VASTBASE_PORT: PositiveInt = Field(
        default=5432,
        description="Port number on which the Vastbase server is listening (default is 5432)",
    )

    VASTBASE_USER: Optional[str] = Field(
        default=None,
        description="Username for authenticating with the Vastbase database",
    )

    VASTBASE_PASSWORD: Optional[str] = Field(
        default=None,
        description="Password for authenticating with the Vastbase database",
    )

    VASTBASE_DATABASE: Optional[str] = Field(
        default=None,
        description="Name of the Vastbase database to connect to",
    )

    VASTBASE_MIN_CONNECTION: PositiveInt = Field(
        default=1,
        description="Min connection of the Vastbase database",
    )

    VASTBASE_MAX_CONNECTION: PositiveInt = Field(
        default=5,
        description="Max connection of the Vastbase database",
    )
@ -0,0 +1,97 @@
|
|||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from pydantic_core.core_schema import ValidationInfo
|
||||||
|
|
||||||
|
from core.ops.utils import replace_text_with_content
|
||||||
|
|
||||||
|
|
||||||
|
class WeaveTokenUsage(BaseModel):
    """Token counts reported for one traced call; every field is optional."""

    # input/output/total token counts as reported by the upstream trace info
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
|
class WeaveMultiModel(BaseModel):
    """Multi-modal payload mixin: optional file attachments for a trace."""

    file_list: Optional[list[str]] = Field(None, description="List of files")
|
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
    """One Weave call: identity, payloads, token usage and file attachments."""

    id: str = Field(..., description="ID of the trace")
    op: str = Field(..., description="Name of the operation")
    inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
    outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
    attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
        None, description="Metadata and attributes associated with trace"
    )
    exception: Optional[str] = Field(None, description="Exception message of the trace")

    @field_validator("inputs", "outputs")
    @classmethod
    def ensure_dict(cls, v, info: ValidationInfo):
        """Normalize inputs/outputs into a dict shape enriched with token usage and files."""
        target = info.field_name
        prior = info.data
        # Empty/absent payloads pass through untouched.
        if v == {} or v is None:
            return v

        token_usage = {
            "input_tokens": prior.get("input_tokens", 0),
            "output_tokens": prior.get("output_tokens", 0),
            "total_tokens": prior.get("total_tokens", 0),
        }
        attachments = prior.get("file_list", [])

        if isinstance(v, str):
            # Wrap a bare string as a single chat message (inputs) or choice (outputs).
            wrapped = {
                "role": "user" if target == "inputs" else "ai",
                "content": v,
                "usage_metadata": token_usage,
                "file_list": attachments,
            }
            if target == "inputs":
                return {"messages": wrapped}
            elif target == "outputs":
                return {"choices": wrapped}
        elif isinstance(v, list):
            data = {}
            if len(v) > 0 and isinstance(v[0], dict):
                # rename text to content
                v = replace_text_with_content(data=v)
                if target == "inputs":
                    data = {
                        "messages": [
                            dict(msg, **{"usage_metadata": token_usage, "file_list": attachments}) for msg in v
                        ]
                        if isinstance(v, list)
                        else v,
                    }
                elif target == "outputs":
                    data = {
                        "choices": {
                            "role": "ai",
                            "content": v,
                            "usage_metadata": token_usage,
                            "file_list": attachments,
                        },
                    }
                return data
            else:
                # List of non-dicts: stringify the whole list into one choice.
                return {
                    "choices": {
                        "role": "ai" if target == "outputs" else "user",
                        "content": str(v),
                        "usage_metadata": token_usage,
                        "file_list": attachments,
                    },
                }
        if isinstance(v, dict):
            # Dict payload: annotate it in place with usage and attachments.
            v["usage_metadata"] = token_usage
            v["file_list"] = attachments
            return v
        return v
@ -0,0 +1,420 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
import weave
|
||||||
|
|
||||||
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
|
from core.ops.entities.config_entity import WeaveConfig
|
||||||
|
from core.ops.entities.trace_entity import (
|
||||||
|
BaseTraceInfo,
|
||||||
|
DatasetRetrievalTraceInfo,
|
||||||
|
GenerateNameTraceInfo,
|
||||||
|
MessageTraceInfo,
|
||||||
|
ModerationTraceInfo,
|
||||||
|
SuggestedQuestionTraceInfo,
|
||||||
|
ToolTraceInfo,
|
||||||
|
TraceTaskName,
|
||||||
|
WorkflowTraceInfo,
|
||||||
|
)
|
||||||
|
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import EndUser, MessageFile
|
||||||
|
from models.workflow import WorkflowNodeExecution
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WeaveDataTrace(BaseTraceInstance):
    """Trace backend that ships Dify ops traces to Weights & Biases Weave."""

    def __init__(
        self,
        weave_config: WeaveConfig,
    ):
        super().__init__(weave_config)
        self.weave_api_key = weave_config.api_key
        self.project_name = weave_config.project
        self.entity = weave_config.entity

        # Authenticate against W&B before initializing the weave client.
        if not wandb.login(key=self.weave_api_key, verify=True, relogin=True):
            logger.error("Failed to login to Weights & Biases with the provided API key")
            raise ValueError("Weave login failed")

        # The project is namespaced by the entity when one is configured.
        target_project = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
        self.weave_client = weave.init(project_name=target_project)
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
        # Open weave calls keyed by our run id; finish_call looks them up here.
        self.calls: dict[str, Any] = {}
|
def get_project_url(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
|
||||||
|
return project_url
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Weave get run url failed: {str(e)}")
|
||||||
|
raise ValueError(f"Weave get run url failed: {str(e)}")
|
||||||
|
|
||||||
|
def trace(self, trace_info: BaseTraceInfo):
|
||||||
|
logger.debug(f"Trace info: {trace_info}")
|
||||||
|
if isinstance(trace_info, WorkflowTraceInfo):
|
||||||
|
self.workflow_trace(trace_info)
|
||||||
|
if isinstance(trace_info, MessageTraceInfo):
|
||||||
|
self.message_trace(trace_info)
|
||||||
|
if isinstance(trace_info, ModerationTraceInfo):
|
||||||
|
self.moderation_trace(trace_info)
|
||||||
|
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||||
|
self.suggested_question_trace(trace_info)
|
||||||
|
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||||
|
self.dataset_retrieval_trace(trace_info)
|
||||||
|
if isinstance(trace_info, ToolTraceInfo):
|
||||||
|
self.tool_trace(trace_info)
|
||||||
|
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||||
|
self.generate_name_trace(trace_info)
|
||||||
|
|
||||||
|
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||||
|
trace_id = trace_info.message_id or trace_info.workflow_run_id
|
||||||
|
if trace_info.start_time is None:
|
||||||
|
trace_info.start_time = datetime.now()
|
||||||
|
|
||||||
|
if trace_info.message_id:
|
||||||
|
message_attributes = trace_info.metadata
|
||||||
|
message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||||
|
|
||||||
|
message_attributes["message_id"] = trace_info.message_id
|
||||||
|
message_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||||
|
message_attributes["trace_id"] = trace_id
|
||||||
|
message_attributes["start_time"] = trace_info.start_time
|
||||||
|
message_attributes["end_time"] = trace_info.end_time
|
||||||
|
message_attributes["tags"] = ["message", "workflow"]
|
||||||
|
|
||||||
|
message_run = WeaveTraceModel(
|
||||||
|
id=trace_info.message_id,
|
||||||
|
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||||
|
inputs=dict(trace_info.workflow_run_inputs),
|
||||||
|
outputs=dict(trace_info.workflow_run_outputs),
|
||||||
|
total_tokens=trace_info.total_tokens,
|
||||||
|
attributes=message_attributes,
|
||||||
|
exception=trace_info.error,
|
||||||
|
file_list=[],
|
||||||
|
)
|
||||||
|
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
|
||||||
|
self.finish_call(message_run)
|
||||||
|
|
||||||
|
workflow_attributes = trace_info.metadata
|
||||||
|
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||||
|
workflow_attributes["trace_id"] = trace_id
|
||||||
|
workflow_attributes["start_time"] = trace_info.start_time
|
||||||
|
workflow_attributes["end_time"] = trace_info.end_time
|
||||||
|
workflow_attributes["tags"] = ["workflow"]
|
||||||
|
|
||||||
|
workflow_run = WeaveTraceModel(
|
||||||
|
file_list=trace_info.file_list,
|
||||||
|
total_tokens=trace_info.total_tokens,
|
||||||
|
id=trace_info.workflow_run_id,
|
||||||
|
op=str(TraceTaskName.WORKFLOW_TRACE.value),
|
||||||
|
inputs=dict(trace_info.workflow_run_inputs),
|
||||||
|
outputs=dict(trace_info.workflow_run_outputs),
|
||||||
|
attributes=workflow_attributes,
|
||||||
|
exception=trace_info.error,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
||||||
|
|
||||||
|
# through workflow_run_id get all_nodes_execution
|
||||||
|
workflow_nodes_execution_id_records = (
|
||||||
|
db.session.query(WorkflowNodeExecution.id)
|
||||||
|
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
for node_execution_id_record in workflow_nodes_execution_id_records:
|
||||||
|
node_execution = (
|
||||||
|
db.session.query(
|
||||||
|
WorkflowNodeExecution.id,
|
||||||
|
WorkflowNodeExecution.tenant_id,
|
||||||
|
WorkflowNodeExecution.app_id,
|
||||||
|
WorkflowNodeExecution.title,
|
||||||
|
WorkflowNodeExecution.node_type,
|
||||||
|
WorkflowNodeExecution.status,
|
||||||
|
WorkflowNodeExecution.inputs,
|
||||||
|
WorkflowNodeExecution.outputs,
|
||||||
|
WorkflowNodeExecution.created_at,
|
||||||
|
WorkflowNodeExecution.elapsed_time,
|
||||||
|
WorkflowNodeExecution.process_data,
|
||||||
|
WorkflowNodeExecution.execution_metadata,
|
||||||
|
)
|
||||||
|
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not node_execution:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_execution_id = node_execution.id
|
||||||
|
tenant_id = node_execution.tenant_id
|
||||||
|
app_id = node_execution.app_id
|
||||||
|
node_name = node_execution.title
|
||||||
|
node_type = node_execution.node_type
|
||||||
|
status = node_execution.status
|
||||||
|
if node_type == "llm":
|
||||||
|
inputs = (
|
||||||
|
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||||
|
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||||
|
created_at = node_execution.created_at or datetime.now()
|
||||||
|
elapsed_time = node_execution.elapsed_time
|
||||||
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
|
execution_metadata = (
|
||||||
|
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||||
|
)
|
||||||
|
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
||||||
|
attributes = execution_metadata.copy()
|
||||||
|
attributes.update(
|
||||||
|
{
|
||||||
|
"workflow_run_id": trace_info.workflow_run_id,
|
||||||
|
"node_execution_id": node_execution_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"app_id": app_id,
|
||||||
|
"app_name": node_name,
|
||||||
|
"node_type": node_type,
|
||||||
|
"status": status,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||||
|
if process_data and process_data.get("model_mode") == "chat":
|
||||||
|
attributes.update(
|
||||||
|
{
|
||||||
|
"ls_provider": process_data.get("model_provider", ""),
|
||||||
|
"ls_model_name": process_data.get("model_name", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
attributes["tags"] = ["node_execution"]
|
||||||
|
attributes["start_time"] = created_at
|
||||||
|
attributes["end_time"] = finished_at
|
||||||
|
attributes["elapsed_time"] = elapsed_time
|
||||||
|
attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||||
|
attributes["trace_id"] = trace_id
|
||||||
|
node_run = WeaveTraceModel(
|
||||||
|
total_tokens=node_total_tokens,
|
||||||
|
op=node_type,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
file_list=trace_info.file_list,
|
||||||
|
attributes=attributes,
|
||||||
|
id=node_execution_id,
|
||||||
|
exception=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
|
||||||
|
self.finish_call(node_run)
|
||||||
|
|
||||||
|
self.finish_call(workflow_run)
|
||||||
|
|
||||||
|
def message_trace(self, trace_info: MessageTraceInfo):
|
||||||
|
# get message file data
|
||||||
|
file_list = cast(list[str], trace_info.file_list) or []
|
||||||
|
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||||
|
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||||
|
file_list.append(file_url)
|
||||||
|
attributes = trace_info.metadata
|
||||||
|
message_data = trace_info.message_data
|
||||||
|
if message_data is None:
|
||||||
|
return
|
||||||
|
message_id = message_data.id
|
||||||
|
|
||||||
|
user_id = message_data.from_account_id
|
||||||
|
attributes["user_id"] = user_id
|
||||||
|
|
||||||
|
if message_data.from_end_user_id:
|
||||||
|
end_user_data: Optional[EndUser] = (
|
||||||
|
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||||
|
)
|
||||||
|
if end_user_data is not None:
|
||||||
|
end_user_id = end_user_data.session_id
|
||||||
|
attributes["end_user_id"] = end_user_id
|
||||||
|
|
||||||
|
attributes["message_id"] = message_id
|
||||||
|
attributes["start_time"] = trace_info.start_time
|
||||||
|
attributes["end_time"] = trace_info.end_time
|
||||||
|
attributes["tags"] = ["message", str(trace_info.conversation_mode)]
|
||||||
|
message_run = WeaveTraceModel(
|
||||||
|
id=message_id,
|
||||||
|
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||||
|
input_tokens=trace_info.message_tokens,
|
||||||
|
output_tokens=trace_info.answer_tokens,
|
||||||
|
total_tokens=trace_info.total_tokens,
|
||||||
|
inputs=trace_info.inputs,
|
||||||
|
outputs=trace_info.outputs,
|
||||||
|
exception=trace_info.error,
|
||||||
|
file_list=file_list,
|
||||||
|
attributes=attributes,
|
||||||
|
)
|
||||||
|
self.start_call(message_run)
|
||||||
|
|
||||||
|
# create llm run parented to message run
|
||||||
|
llm_run = WeaveTraceModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
input_tokens=trace_info.message_tokens,
|
||||||
|
output_tokens=trace_info.answer_tokens,
|
||||||
|
total_tokens=trace_info.total_tokens,
|
||||||
|
op="llm",
|
||||||
|
inputs=trace_info.inputs,
|
||||||
|
outputs=trace_info.outputs,
|
||||||
|
attributes=attributes,
|
||||||
|
file_list=[],
|
||||||
|
exception=None,
|
||||||
|
)
|
||||||
|
self.start_call(
|
||||||
|
llm_run,
|
||||||
|
parent_run_id=message_id,
|
||||||
|
)
|
||||||
|
self.finish_call(llm_run)
|
||||||
|
self.finish_call(message_run)
|
||||||
|
|
||||||
|
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||||
|
if trace_info.message_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
attributes = trace_info.metadata
|
||||||
|
attributes["tags"] = ["moderation"]
|
||||||
|
attributes["message_id"] = trace_info.message_id
|
||||||
|
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
|
||||||
|
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
|
||||||
|
|
||||||
|
moderation_run = WeaveTraceModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
op=str(TraceTaskName.MODERATION_TRACE.value),
|
||||||
|
inputs=trace_info.inputs,
|
||||||
|
outputs={
|
||||||
|
"action": trace_info.action,
|
||||||
|
"flagged": trace_info.flagged,
|
||||||
|
"preset_response": trace_info.preset_response,
|
||||||
|
"inputs": trace_info.inputs,
|
||||||
|
},
|
||||||
|
attributes=attributes,
|
||||||
|
exception=getattr(trace_info, "error", None),
|
||||||
|
file_list=[],
|
||||||
|
)
|
||||||
|
self.start_call(moderation_run, parent_run_id=trace_info.message_id)
|
||||||
|
self.finish_call(moderation_run)
|
||||||
|
|
||||||
|
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||||
|
message_data = trace_info.message_data
|
||||||
|
if message_data is None:
|
||||||
|
return
|
||||||
|
attributes = trace_info.metadata
|
||||||
|
attributes["message_id"] = trace_info.message_id
|
||||||
|
attributes["tags"] = ["suggested_question"]
|
||||||
|
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
|
||||||
|
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
|
||||||
|
|
||||||
|
suggested_question_run = WeaveTraceModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
|
||||||
|
inputs=trace_info.inputs,
|
||||||
|
outputs=trace_info.suggested_question,
|
||||||
|
attributes=attributes,
|
||||||
|
exception=trace_info.error,
|
||||||
|
file_list=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
|
||||||
|
self.finish_call(suggested_question_run)
|
||||||
|
|
||||||
|
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||||
|
if trace_info.message_data is None:
|
||||||
|
return
|
||||||
|
attributes = trace_info.metadata
|
||||||
|
attributes["message_id"] = trace_info.message_id
|
||||||
|
attributes["tags"] = ["dataset_retrieval"]
|
||||||
|
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
|
||||||
|
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
|
||||||
|
|
||||||
|
dataset_retrieval_run = WeaveTraceModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
|
||||||
|
inputs=trace_info.inputs,
|
||||||
|
outputs={"documents": trace_info.documents},
|
||||||
|
attributes=attributes,
|
||||||
|
exception=getattr(trace_info, "error", None),
|
||||||
|
file_list=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
|
||||||
|
self.finish_call(dataset_retrieval_run)
|
||||||
|
|
||||||
|
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||||
|
attributes = trace_info.metadata
|
||||||
|
attributes["tags"] = ["tool", trace_info.tool_name]
|
||||||
|
attributes["start_time"] = trace_info.start_time
|
||||||
|
attributes["end_time"] = trace_info.end_time
|
||||||
|
|
||||||
|
tool_run = WeaveTraceModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
op=trace_info.tool_name,
|
||||||
|
inputs=trace_info.tool_inputs,
|
||||||
|
outputs=trace_info.tool_outputs,
|
||||||
|
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
|
||||||
|
attributes=attributes,
|
||||||
|
exception=trace_info.error,
|
||||||
|
)
|
||||||
|
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
|
||||||
|
message_id = message_id or None
|
||||||
|
self.start_call(tool_run, parent_run_id=message_id)
|
||||||
|
self.finish_call(tool_run)
|
||||||
|
|
||||||
|
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||||
|
attributes = trace_info.metadata
|
||||||
|
attributes["tags"] = ["generate_name"]
|
||||||
|
attributes["start_time"] = trace_info.start_time
|
||||||
|
attributes["end_time"] = trace_info.end_time
|
||||||
|
|
||||||
|
name_run = WeaveTraceModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
|
||||||
|
inputs=trace_info.inputs,
|
||||||
|
outputs=trace_info.outputs,
|
||||||
|
attributes=attributes,
|
||||||
|
exception=getattr(trace_info, "error", None),
|
||||||
|
file_list=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.start_call(name_run)
|
||||||
|
self.finish_call(name_run)
|
||||||
|
|
||||||
|
def api_check(self):
|
||||||
|
try:
|
||||||
|
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||||
|
if not login_status:
|
||||||
|
raise ValueError("Weave login failed")
|
||||||
|
else:
|
||||||
|
print("Weave login successful")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Weave API check failed: {str(e)}")
|
||||||
|
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||||
|
|
||||||
|
def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
|
||||||
|
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
|
||||||
|
self.calls[run_data.id] = call
|
||||||
|
if parent_run_id:
|
||||||
|
self.calls[run_data.id].parent_id = parent_run_id
|
||||||
|
|
||||||
|
def finish_call(self, run_data: WeaveTraceModel):
|
||||||
|
call = self.calls.get(run_data.id)
|
||||||
|
if call:
|
||||||
|
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Call with id {run_data.id} not found")
|
||||||
@ -1,7 +1,7 @@
|
|||||||
from core.plugin.manager.base import BasePluginManager
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
class PluginAssetManager(BasePluginManager):
|
class PluginAssetManager(BasePluginClient):
|
||||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Fetch an asset by id.
|
Fetch an asset by id.
|
||||||
@ -1,9 +1,9 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.plugin.manager.base import BasePluginManager
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
class PluginDebuggingManager(BasePluginManager):
|
class PluginDebuggingClient(BasePluginClient):
|
||||||
def get_debugging_key(self, tenant_id: str) -> str:
|
def get_debugging_key(self, tenant_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Get the debugging key for the given tenant.
|
Get the debugging key for the given tenant.
|
||||||
@ -1,8 +1,8 @@
|
|||||||
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
||||||
from core.plugin.manager.base import BasePluginManager
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
class PluginEndpointManager(BasePluginManager):
|
class PluginEndpointClient(BasePluginClient):
|
||||||
def create_endpoint(
|
def create_endpoint(
|
||||||
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
|
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@ -0,0 +1,98 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from werkzeug import Request
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin_daemon import PluginOAuthAuthorizationUrlResponse, PluginOAuthCredentialsResponse
|
||||||
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthHandler(BasePluginClient):
    """Client for the plugin daemon's OAuth dispatch endpoints."""

    def get_authorization_url(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        system_credentials: Mapping[str, Any],
    ) -> PluginOAuthAuthorizationUrlResponse:
        """Ask the plugin daemon for the provider's OAuth authorization URL."""
        payload = {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "system_credentials": system_credentials,
            },
        }
        request_headers = {
            "X-Plugin-ID": plugin_id,
            "Content-Type": "application/json",
        }
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
            PluginOAuthAuthorizationUrlResponse,
            data=payload,
            headers=request_headers,
        )

    def get_credentials(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        system_credentials: Mapping[str, Any],
        request: Request,
    ) -> PluginOAuthCredentialsResponse:
        """
        Get credentials from the given request.
        """
        # The daemon extracts the OAuth callback parameters from the raw HTTP
        # request, so serialize the incoming request first.
        raw_request_bytes = self._convert_request_to_raw_data(request)

        payload = {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "system_credentials": system_credentials,
                "raw_request_bytes": raw_request_bytes,
            },
        }
        request_headers = {
            "X-Plugin-ID": plugin_id,
            "Content-Type": "application/json",
        }
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
            PluginOAuthCredentialsResponse,
            data=payload,
            headers=request_headers,
        )

    def _convert_request_to_raw_data(self, request: Request) -> bytes:
        """
        Convert a Request object to raw HTTP data.

        Args:
            request: The Request object to convert.

        Returns:
            The raw HTTP data as bytes.
        """
        # Request line: "<METHOD> <PATH> <PROTOCOL>".
        # NOTE(review): werkzeug normally exposes the protocol via the WSGI
        # environ (SERVER_PROTOCOL), not as an "HTTP_VERSION" header — confirm
        # this lookup ever succeeds; it currently falls back to HTTP/1.1.
        protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
        lines = [f"{request.method} {request.path} {protocol}".encode()]

        # Header block, then the blank line separating headers from the body.
        lines.extend(f"{name}: {value}".encode() for name, value in request.headers.items())
        lines.append(b"")
        head = b"\r\n".join(lines) + b"\r\n"

        # Append the body when present.
        body = request.get_data(as_text=False)
        return head + body if body else head
@ -0,0 +1,243 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import psycopg2.extras # type: ignore
|
||||||
|
import psycopg2.pool # type: ignore
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||||
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class VastbaseVectorConfig(BaseModel):
    """Validated connection settings for a Vastbase (vector extension) server."""

    host: str
    port: int
    user: str
    password: str
    database: str
    min_connection: int
    max_connection: int

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """Ensure every connection setting is present and the pool bounds are sane.

        Fix: the original indexed ``values["..."]`` directly, so a missing key
        raised a bare ``KeyError`` instead of the intended friendly error.
        ``dict.get`` is used so absent and falsy values are reported the same way.
        """
        required = {
            "host": "VASTBASE_HOST",
            "port": "VASTBASE_PORT",
            "user": "VASTBASE_USER",
            "password": "VASTBASE_PASSWORD",
            "database": "VASTBASE_DATABASE",
            "min_connection": "VASTBASE_MIN_CONNECTION",
            "max_connection": "VASTBASE_MAX_CONNECTION",
        }
        for key, env_name in required.items():
            if not values.get(key):
                raise ValueError(f"config {env_name} is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION")
        return values
||||||
|
|
||||||
|
# DDL template for the per-collection embedding table. `floatvector` is the
# vector column type provided by Vastbase's vector extension.
SQL_CREATE_TABLE = """
    CREATE TABLE IF NOT EXISTS {table_name} (
        id UUID PRIMARY KEY,
        text TEXT NOT NULL,
        meta JSONB NOT NULL,
        embedding floatvector({dimension}) NOT NULL
    );
"""

# HNSW index over the embedding column using cosine distance.
SQL_CREATE_INDEX = """
    CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
    USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
|
class VastbaseVector(BaseVector):
    def __init__(self, collection_name: str, config: VastbaseVectorConfig):
        """Bind the store to one collection and open a connection pool."""
        super().__init__(collection_name)
        self.pool = self._create_connection_pool(config)
        # One table per collection, named "embedding_<collection>".
        self.table_name = f"embedding_{collection_name}"
|
def get_type(self) -> str:
|
||||||
|
return VectorType.VASTBASE
|
||||||
|
|
||||||
|
def _create_connection_pool(self, config: VastbaseVectorConfig):
|
||||||
|
return psycopg2.pool.SimpleConnectionPool(
|
||||||
|
config.min_connection,
|
||||||
|
config.max_connection,
|
||||||
|
host=config.host,
|
||||||
|
port=config.port,
|
||||||
|
user=config.user,
|
||||||
|
password=config.password,
|
||||||
|
database=config.database,
|
||||||
|
)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _get_cursor(self):
|
||||||
|
conn = self.pool.getconn()
|
||||||
|
cur = conn.cursor()
|
||||||
|
try:
|
||||||
|
yield cur
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
conn.commit()
|
||||||
|
self.pool.putconn(conn)
|
||||||
|
|
||||||
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
dimension = len(embeddings[0])
|
||||||
|
self._create_collection(dimension)
|
||||||
|
return self.add_texts(texts, embeddings)
|
||||||
|
|
||||||
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
values = []
|
||||||
|
pks = []
|
||||||
|
for i, doc in enumerate(documents):
|
||||||
|
if doc.metadata is not None:
|
||||||
|
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||||
|
pks.append(doc_id)
|
||||||
|
values.append(
|
||||||
|
(
|
||||||
|
doc_id,
|
||||||
|
doc.page_content,
|
||||||
|
json.dumps(doc.metadata),
|
||||||
|
embeddings[i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
psycopg2.extras.execute_values(
|
||||||
|
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
|
||||||
|
)
|
||||||
|
return pks
|
||||||
|
|
||||||
|
def text_exists(self, id: str) -> bool:
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
|
||||||
|
return cur.fetchone() is not None
|
||||||
|
|
||||||
|
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||||
|
docs = []
|
||||||
|
for record in cur:
|
||||||
|
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
|
||||||
|
# Scenario 1: extract a document fails, resulting in a table not being created.
|
||||||
|
# Then clicking the retry button triggers a delete operation on an empty list.
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||||
|
|
||||||
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||||
|
|
||||||
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Search the nearest neighbors to a vector.
|
||||||
|
|
||||||
|
:param query_vector: The input vector to search for similar items.
|
||||||
|
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||||
|
:return: List of Documents that are nearest to the query vector.
|
||||||
|
"""
|
||||||
|
top_k = kwargs.get("top_k", 4)
|
||||||
|
|
||||||
|
if not isinstance(top_k, int) or top_k <= 0:
|
||||||
|
raise ValueError("top_k must be a positive integer")
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||||
|
f" ORDER BY distance LIMIT {top_k}",
|
||||||
|
(json.dumps(query_vector),),
|
||||||
|
)
|
||||||
|
docs = []
|
||||||
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
|
for record in cur:
|
||||||
|
metadata, text, distance = record
|
||||||
|
score = 1 - distance
|
||||||
|
metadata["score"] = score
|
||||||
|
if score > score_threshold:
|
||||||
|
docs.append(Document(page_content=text, metadata=metadata))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
|
top_k = kwargs.get("top_k", 5)
|
||||||
|
|
||||||
|
if not isinstance(top_k, int) or top_k <= 0:
|
||||||
|
raise ValueError("top_k must be a positive integer")
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||||
|
FROM {self.table_name}
|
||||||
|
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT {top_k}""",
|
||||||
|
# f"'{query}'" is required in order to account for whitespace in query
|
||||||
|
(f"'{query}'", f"'{query}'"),
|
||||||
|
)
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
|
||||||
|
for record in cur:
|
||||||
|
metadata, text, score = record
|
||||||
|
metadata["score"] = score
|
||||||
|
docs.append(Document(page_content=text, metadata=metadata))
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||||
|
|
||||||
|
def _create_collection(self, dimension: int):
|
||||||
|
cache_key = f"vector_indexing_{self._collection_name}"
|
||||||
|
lock_name = f"{cache_key}_lock"
|
||||||
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
|
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||||
|
if redis_client.get(collection_exist_cache_key):
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._get_cursor() as cur:
|
||||||
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||||
|
# Vastbase 支持的向量维度取值范围为 [1,16000]
|
||||||
|
if dimension <= 16000:
|
||||||
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||||
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
|
|
||||||
|
class VastbaseVectorFactory(AbstractVectorFactory):
    """Factory that wires a dataset to a :class:`VastbaseVector` instance."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VastbaseVector:
        """Resolve the collection name for ``dataset`` and build the store.

        An existing ``index_struct_dict`` supplies the collection name via its
        class prefix; otherwise a new name is generated from the dataset id and
        the resulting index struct is persisted on the dataset.
        """
        if dataset.index_struct_dict:
            collection_name: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
        else:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)
            index_struct = self.gen_index_struct_dict(VectorType.VASTBASE, collection_name)
            dataset.index_struct = json.dumps(index_struct)

        config = VastbaseVectorConfig(
            host=dify_config.VASTBASE_HOST or "localhost",
            port=dify_config.VASTBASE_PORT,
            user=dify_config.VASTBASE_USER or "dify",
            password=dify_config.VASTBASE_PASSWORD or "",
            database=dify_config.VASTBASE_DATABASE or "dify",
            min_connection=dify_config.VASTBASE_MIN_CONNECTION,
            max_connection=dify_config.VASTBASE_MAX_CONNECTION,
        )
        return VastbaseVector(collection_name=collection_name, config=config)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue