Merge branch 'main' into e-300
commit
62347206c0
@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class VastbaseVectorConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Vector (Vastbase with vector extension)
|
||||
"""
|
||||
|
||||
VASTBASE_HOST: Optional[str] = Field(
|
||||
description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_PORT: PositiveInt = Field(
|
||||
description="Port number on which the Vastbase server is listening (default is 5432)",
|
||||
default=5432,
|
||||
)
|
||||
|
||||
VASTBASE_USER: Optional[str] = Field(
|
||||
description="Username for authenticating with the Vastbase database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_PASSWORD: Optional[str] = Field(
|
||||
description="Password for authenticating with the Vastbase database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_DATABASE: Optional[str] = Field(
|
||||
description="Name of the Vastbase database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_MIN_CONNECTION: PositiveInt = Field(
|
||||
description="Min connection of the Vastbase database",
|
||||
default=1,
|
||||
)
|
||||
|
||||
VASTBASE_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Max connection of the Vastbase database",
|
||||
default=5,
|
||||
)
|
||||
@ -0,0 +1,97 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.ops.utils import replace_text_with_content
|
||||
|
||||
|
||||
class WeaveTokenUsage(BaseModel):
|
||||
input_tokens: Optional[int] = None
|
||||
output_tokens: Optional[int] = None
|
||||
total_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class WeaveMultiModel(BaseModel):
|
||||
file_list: Optional[list[str]] = Field(None, description="List of files")
|
||||
|
||||
|
||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
id: str = Field(..., description="ID of the trace")
|
||||
op: str = Field(..., description="Name of the operation")
|
||||
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
||||
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
||||
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
None, description="Metadata and attributes associated with trace"
|
||||
)
|
||||
exception: Optional[str] = Field(None, description="Exception message of the trace")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
values = info.data
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
usage_metadata = {
|
||||
"input_tokens": values.get("input_tokens", 0),
|
||||
"output_tokens": values.get("output_tokens", 0),
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
if field_name == "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
return v
|
||||
return v
|
||||
@ -0,0 +1,420 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeaveDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
weave_config: WeaveConfig,
|
||||
):
|
||||
super().__init__(weave_config)
|
||||
self.weave_api_key = weave_config.api_key
|
||||
self.project_name = weave_config.project
|
||||
self.entity = weave_config.entity
|
||||
|
||||
# Login with API key first
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
if not login_status:
|
||||
logger.error("Failed to login to Weights & Biases with the provided API key")
|
||||
raise ValueError("Weave login failed")
|
||||
|
||||
# Then initialize weave client
|
||||
self.weave_client = weave.init(
|
||||
project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
|
||||
)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.calls: dict[str, Any] = {}
|
||||
|
||||
def get_project_url(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
|
||||
return project_url
|
||||
except Exception as e:
|
||||
logger.debug(f"Weave get run url failed: {str(e)}")
|
||||
raise ValueError(f"Weave get run url failed: {str(e)}")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.debug(f"Trace info: {trace_info}")
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
|
||||
if trace_info.message_id:
|
||||
message_attributes = trace_info.metadata
|
||||
message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
message_attributes["message_id"] = trace_info.message_id
|
||||
message_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
message_attributes["trace_id"] = trace_id
|
||||
message_attributes["start_time"] = trace_info.start_time
|
||||
message_attributes["end_time"] = trace_info.end_time
|
||||
message_attributes["tags"] = ["message", "workflow"]
|
||||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_info.message_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
total_tokens=trace_info.total_tokens,
|
||||
attributes=message_attributes,
|
||||
exception=trace_info.error,
|
||||
file_list=[],
|
||||
)
|
||||
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
|
||||
self.finish_call(message_run)
|
||||
|
||||
workflow_attributes = trace_info.metadata
|
||||
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
workflow_attributes["trace_id"] = trace_id
|
||||
workflow_attributes["start_time"] = trace_info.start_time
|
||||
workflow_attributes["end_time"] = trace_info.end_time
|
||||
workflow_attributes["tags"] = ["workflow"]
|
||||
|
||||
workflow_run = WeaveTraceModel(
|
||||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
op=str(TraceTaskName.WORKFLOW_TRACE.value),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
attributes=workflow_attributes,
|
||||
exception=trace_info.error,
|
||||
)
|
||||
|
||||
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution
|
||||
workflow_nodes_execution_id_records = (
|
||||
db.session.query(WorkflowNodeExecution.id)
|
||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
||||
node_execution = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecution.id,
|
||||
WorkflowNodeExecution.tenant_id,
|
||||
WorkflowNodeExecution.app_id,
|
||||
WorkflowNodeExecution.title,
|
||||
WorkflowNodeExecution.node_type,
|
||||
WorkflowNodeExecution.status,
|
||||
WorkflowNodeExecution.inputs,
|
||||
WorkflowNodeExecution.outputs,
|
||||
WorkflowNodeExecution.created_at,
|
||||
WorkflowNodeExecution.elapsed_time,
|
||||
WorkflowNodeExecution.process_data,
|
||||
WorkflowNodeExecution.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not node_execution:
|
||||
continue
|
||||
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = node_execution.tenant_id
|
||||
app_id = node_execution.app_id
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == "llm":
|
||||
inputs = (
|
||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||
)
|
||||
else:
|
||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = (
|
||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
)
|
||||
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
||||
attributes = execution_metadata.copy()
|
||||
attributes.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_name,
|
||||
"node_type": node_type,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
attributes.update(
|
||||
{
|
||||
"ls_provider": process_data.get("model_provider", ""),
|
||||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
attributes["tags"] = ["node_execution"]
|
||||
attributes["start_time"] = created_at
|
||||
attributes["end_time"] = finished_at
|
||||
attributes["elapsed_time"] = elapsed_time
|
||||
attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
attributes["trace_id"] = trace_id
|
||||
node_run = WeaveTraceModel(
|
||||
total_tokens=node_total_tokens,
|
||||
op=node_type,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
file_list=trace_info.file_list,
|
||||
attributes=attributes,
|
||||
id=node_execution_id,
|
||||
exception=None,
|
||||
)
|
||||
|
||||
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
|
||||
self.finish_call(node_run)
|
||||
|
||||
self.finish_call(workflow_run)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
attributes = trace_info.metadata
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = message_data.id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
attributes["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
attributes["end_user_id"] = end_user_id
|
||||
|
||||
attributes["message_id"] = message_id
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
attributes["tags"] = ["message", str(trace_info.conversation_mode)]
|
||||
message_run = WeaveTraceModel(
|
||||
id=message_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
exception=trace_info.error,
|
||||
file_list=file_list,
|
||||
attributes=attributes,
|
||||
)
|
||||
self.start_call(message_run)
|
||||
|
||||
# create llm run parented to message run
|
||||
llm_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
op="llm",
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
file_list=[],
|
||||
exception=None,
|
||||
)
|
||||
self.start_call(
|
||||
llm_run,
|
||||
parent_run_id=message_id,
|
||||
)
|
||||
self.finish_call(llm_run)
|
||||
self.finish_call(message_run)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["moderation"]
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
|
||||
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
moderation_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.MODERATION_TRACE.value),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
self.start_call(moderation_run, parent_run_id=trace_info.message_id)
|
||||
self.finish_call(moderation_run)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
attributes = trace_info.metadata
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["tags"] = ["suggested_question"]
|
||||
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
|
||||
|
||||
suggested_question_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
attributes=attributes,
|
||||
exception=trace_info.error,
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
|
||||
self.finish_call(suggested_question_run)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
attributes = trace_info.metadata
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["tags"] = ["dataset_retrieval"]
|
||||
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
|
||||
|
||||
dataset_retrieval_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
|
||||
self.finish_call(dataset_retrieval_run)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["tool", trace_info.tool_name]
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
|
||||
tool_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=trace_info.tool_name,
|
||||
inputs=trace_info.tool_inputs,
|
||||
outputs=trace_info.tool_outputs,
|
||||
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
|
||||
attributes=attributes,
|
||||
exception=trace_info.error,
|
||||
)
|
||||
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
|
||||
message_id = message_id or None
|
||||
self.start_call(tool_run, parent_run_id=message_id)
|
||||
self.finish_call(tool_run)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["generate_name"]
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
|
||||
name_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(name_run)
|
||||
self.finish_call(name_run)
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
print("Weave login successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Weave API check failed: {str(e)}")
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
|
||||
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
|
||||
self.calls[run_data.id] = call
|
||||
if parent_run_id:
|
||||
self.calls[run_data.id].parent_id = parent_run_id
|
||||
|
||||
def finish_call(self, run_data: WeaveTraceModel):
|
||||
call = self.calls.get(run_data.id)
|
||||
if call:
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
|
||||
else:
|
||||
raise ValueError(f"Call with id {run_data.id} not found")
|
||||
@ -1,7 +1,7 @@
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginAssetManager(BasePluginManager):
|
||||
class PluginAssetManager(BasePluginClient):
|
||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||
"""
|
||||
Fetch an asset by id.
|
||||
@ -1,9 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginDebuggingManager(BasePluginManager):
|
||||
class PluginDebuggingClient(BasePluginClient):
|
||||
def get_debugging_key(self, tenant_id: str) -> str:
|
||||
"""
|
||||
Get the debugging key for the given tenant.
|
||||
@ -1,8 +1,8 @@
|
||||
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginEndpointManager(BasePluginManager):
|
||||
class PluginEndpointClient(BasePluginClient):
|
||||
def create_endpoint(
|
||||
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
|
||||
) -> bool:
|
||||
@ -0,0 +1,98 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from werkzeug import Request
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginOAuthAuthorizationUrlResponse, PluginOAuthCredentialsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class OAuthHandler(BasePluginClient):
|
||||
def get_authorization_url(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
system_credentials: Mapping[str, Any],
|
||||
) -> PluginOAuthAuthorizationUrlResponse:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
||||
PluginOAuthAuthorizationUrlResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"system_credentials": system_credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
system_credentials: Mapping[str, Any],
|
||||
request: Request,
|
||||
) -> PluginOAuthCredentialsResponse:
|
||||
"""
|
||||
Get credentials from the given request.
|
||||
"""
|
||||
|
||||
# encode request to raw http request
|
||||
raw_request_bytes = self._convert_request_to_raw_data(request)
|
||||
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
||||
PluginOAuthCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"system_credentials": system_credentials,
|
||||
"raw_request_bytes": raw_request_bytes,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||
"""
|
||||
Convert a Request object to raw HTTP data.
|
||||
|
||||
Args:
|
||||
request: The Request object to convert.
|
||||
|
||||
Returns:
|
||||
The raw HTTP data as bytes.
|
||||
"""
|
||||
# Start with the request line
|
||||
method = request.method
|
||||
path = request.path
|
||||
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
||||
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
||||
|
||||
# Add headers
|
||||
for header_name, header_value in request.headers.items():
|
||||
raw_data += f"{header_name}: {header_value}\r\n".encode()
|
||||
|
||||
# Add empty line to separate headers from body
|
||||
raw_data += b"\r\n"
|
||||
|
||||
# Add body if exists
|
||||
body = request.get_data(as_text=False)
|
||||
if body:
|
||||
raw_data += body
|
||||
|
||||
return raw_data
|
||||
@ -0,0 +1,243 @@
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class VastbaseVectorConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
min_connection: int
|
||||
max_connection: int
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config VASTBASE_HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config VASTBASE_PORT is required")
|
||||
if not values["user"]:
|
||||
raise ValueError("config VASTBASE_USER is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config VASTBASE_PASSWORD is required")
|
||||
if not values["database"]:
|
||||
raise ValueError("config VASTBASE_DATABASE is required")
|
||||
if not values["min_connection"]:
|
||||
raise ValueError("config VASTBASE_MIN_CONNECTION is required")
|
||||
if not values["max_connection"]:
|
||||
raise ValueError("config VASTBASE_MAX_CONNECTION is required")
|
||||
if values["min_connection"] > values["max_connection"]:
|
||||
raise ValueError("config VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION")
|
||||
return values
|
||||
|
||||
|
||||
SQL_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
text TEXT NOT NULL,
|
||||
meta JSONB NOT NULL,
|
||||
embedding floatvector({dimension}) NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
SQL_CREATE_INDEX = """
|
||||
CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
|
||||
USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
|
||||
"""
|
||||
|
||||
|
||||
class VastbaseVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: VastbaseVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self.pool = self._create_connection_pool(config)
|
||||
self.table_name = f"embedding_{collection_name}"
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.VASTBASE
|
||||
|
||||
def _create_connection_pool(self, config: VastbaseVectorConfig):
|
||||
return psycopg2.pool.SimpleConnectionPool(
|
||||
config.min_connection,
|
||||
config.max_connection,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
database=config.database,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.commit()
|
||||
self.pool.putconn(conn)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
values = []
|
||||
pks = []
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
embeddings[i],
|
||||
)
|
||||
)
|
||||
with self._get_cursor() as cur:
|
||||
psycopg2.extras.execute_values(
|
||||
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
|
||||
)
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
|
||||
return cur.fetchone() is not None
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
|
||||
# Scenario 1: extract a document fails, resulting in a table not being created.
|
||||
# Then clicking the retry button triggers a delete operation on an empty list.
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search the nearest neighbors to a vector.
|
||||
|
||||
:param query_vector: The input vector to search for similar items.
|
||||
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||
f" ORDER BY distance LIMIT {top_k}",
|
||||
(json.dumps(query_vector),),
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
|
||||
docs = []
|
||||
|
||||
for record in cur:
|
||||
metadata, text, score = record
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
# Vastbase 支持的向量维度取值范围为 [1,16000]
|
||||
if dimension <= 16000:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class VastbaseVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VastbaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VASTBASE, collection_name))
|
||||
|
||||
return VastbaseVector(
|
||||
collection_name=collection_name,
|
||||
config=VastbaseVectorConfig(
|
||||
host=dify_config.VASTBASE_HOST or "localhost",
|
||||
port=dify_config.VASTBASE_PORT,
|
||||
user=dify_config.VASTBASE_USER or "dify",
|
||||
password=dify_config.VASTBASE_PASSWORD or "",
|
||||
database=dify_config.VASTBASE_DATABASE or "dify",
|
||||
min_connection=dify_config.VASTBASE_MIN_CONNECTION,
|
||||
max_connection=dify_config.VASTBASE_MAX_CONNECTION,
|
||||
),
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue