chore: merge main and add docs
commit
a408459875
@ -1,67 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from constants import IMAGE_EXTENSIONS
|
||||
from core.helper.url_signer import UrlSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.exception(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file_id: the id of UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
result = UrlSigner.verify(
|
||||
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
|
||||
)
|
||||
|
||||
# verify signature
|
||||
if not result:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -1,22 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, capacity: int):
|
||||
self.cache: OrderedDict[Any, Any] = OrderedDict()
|
||||
self.capacity = capacity
|
||||
|
||||
def get(self, key: Any) -> Any:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
else:
|
||||
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key: Any, value: Any) -> None:
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
self.cache[key] = value
|
||||
if len(self.cache) > self.capacity:
|
||||
self.cache.popitem(last=False) # pop the first item
|
||||
@ -0,0 +1,487 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_SYSTEM,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AliyunDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
aliyun_config: AliyunConfig,
|
||||
):
|
||||
super().__init__(aliyun_config)
|
||||
base_url = aliyun_config.endpoint.rstrip("/")
|
||||
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
|
||||
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
pass
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
pass
|
||||
|
||||
def api_check(self):
|
||||
return self.trace_client.api_check()
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
return self.trace_client.get_project_url()
|
||||
except Exception as e:
|
||||
logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Aliyun get run url failed: {str(e)}")
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
|
||||
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
|
||||
|
||||
workflow_node_executions = self.get_workflow_node_executions(trace_info)
|
||||
for node_execution in workflow_node_executions:
|
||||
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
message_span_id = convert_to_span_id(message_id, "message")
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
|
||||
pre_prompt = getattr(app_model_config, "pre_prompt", "")
|
||||
inputs_data = getattr(trace_info.message_data, "inputs", {})
|
||||
llm_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=convert_to_span_id(message_id, "llm"),
|
||||
name="llm",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(trace_info.outputs),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(llm_span)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
documents_data = extract_retrieval_documents(trace_info.documents)
|
||||
dataset_retrieval_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name="dataset_retrieval",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: str(trace_info.inputs),
|
||||
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
|
||||
INPUT_VALUE: str(trace_info.inputs),
|
||||
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
self.trace_client.add_span(dataset_retrieval_span)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
tool_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name=trace_info.tool_name,
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: trace_info.tool_name,
|
||||
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
return workflow_node_executions
|
||||
|
||||
def build_workflow_node_span(
|
||||
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
|
||||
):
|
||||
try:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
return node_span
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
|
||||
span_status: Status = Status(StatusCode.UNSET)
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
span_status = Status(StatusCode.OK)
|
||||
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
span_status = Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return span_status
|
||||
|
||||
def build_workflow_task_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_tool_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
tool_des = {}
|
||||
if node_execution.metadata:
|
||||
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_retrieval_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
input_value = ""
|
||||
if node_execution.inputs:
|
||||
input_value = str(node_execution.inputs.get("query", ""))
|
||||
output_value = ""
|
||||
if node_execution.outputs:
|
||||
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: input_value,
|
||||
RETRIEVAL_DOCUMENT: output_value,
|
||||
INPUT_VALUE: input_value,
|
||||
OUTPUT_VALUE: output_value,
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_llm_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
|
||||
GEN_AI_SYSTEM: process_data.get("model_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(outputs.get("text", "")),
|
||||
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
|
||||
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(outputs.get("text", "")),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
|
||||
message_span_id = None
|
||||
if trace_info.message_id:
|
||||
message_span_id = convert_to_span_id(trace_info.message_id, "message")
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
if message_span_id: # chatflow
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
workflow_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=workflow_span_id,
|
||||
name="workflow",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_id = trace_info.message_id
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
suggested_question_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=convert_to_span_id(message_id, "suggested_question"),
|
||||
name="suggested_question",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(suggested_question_span)
|
||||
|
||||
|
||||
def extract_retrieval_documents(documents: list[Document]):
|
||||
documents_data = []
|
||||
for document in documents:
|
||||
document_data = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
@ -0,0 +1,200 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TraceClient:
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
endpoint: str,
|
||||
max_queue_size: int = 1000,
|
||||
schedule_delay_sec: int = 5,
|
||||
max_export_batch_size: int = 50,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
}
|
||||
)
|
||||
self.span_builder = SpanBuilder(self.resource)
|
||||
self.exporter = OTLPSpanExporter(endpoint=endpoint)
|
||||
|
||||
self.max_queue_size = max_queue_size
|
||||
self.schedule_delay_sec = schedule_delay_sec
|
||||
self.max_export_batch_size = max_export_batch_size
|
||||
|
||||
self.queue: deque = deque(maxlen=max_queue_size)
|
||||
self.condition = threading.Condition(threading.Lock())
|
||||
self.done = False
|
||||
|
||||
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
self._spans_dropped = False
|
||||
|
||||
def export(self, spans: Sequence[ReadableSpan]):
|
||||
self.exporter.export(spans)
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
response = requests.head(self.endpoint, timeout=5)
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.debug(f"AliyunTrace API check failed: {str(e)}")
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
return "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
def add_span(self, span_data: SpanData):
|
||||
if span_data is None:
|
||||
return
|
||||
span: ReadableSpan = self.span_builder.build_span(span_data)
|
||||
with self.condition:
|
||||
if len(self.queue) == self.max_queue_size:
|
||||
if not self._spans_dropped:
|
||||
logger.warning("Queue is full, likely spans will be dropped.")
|
||||
self._spans_dropped = True
|
||||
|
||||
self.queue.appendleft(span)
|
||||
if len(self.queue) >= self.max_export_batch_size:
|
||||
self.condition.notify()
|
||||
|
||||
def _worker(self):
|
||||
while not self.done:
|
||||
with self.condition:
|
||||
if len(self.queue) < self.max_export_batch_size and not self.done:
|
||||
self.condition.wait(timeout=self.schedule_delay_sec)
|
||||
self._export_batch()
|
||||
|
||||
def _export_batch(self):
|
||||
spans_to_export: list[ReadableSpan] = []
|
||||
with self.condition:
|
||||
while len(spans_to_export) < self.max_export_batch_size and self.queue:
|
||||
spans_to_export.append(self.queue.pop())
|
||||
|
||||
if spans_to_export:
|
||||
try:
|
||||
self.exporter.export(spans_to_export)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error exporting spans: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
with self.condition:
|
||||
self.done = True
|
||||
self.condition.notify_all()
|
||||
self.worker_thread.join()
|
||||
self._export_batch()
|
||||
self.exporter.shutdown()
|
||||
|
||||
|
||||
class SpanBuilder:
|
||||
def __init__(self, resource):
|
||||
self.resource = resource
|
||||
self.instrumentation_scope = InstrumentationScope(
|
||||
__name__,
|
||||
"",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
def build_span(self, span_data: SpanData) -> ReadableSpan:
|
||||
span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
parent_span_context = None
|
||||
if span_data.parent_span_id is not None:
|
||||
parent_span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.parent_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
span = ReadableSpan(
|
||||
name=span_data.name,
|
||||
context=span_context,
|
||||
parent=parent_span_context,
|
||||
resource=self.resource,
|
||||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=trace_api.SpanKind.INTERNAL,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
instrumentation_scope=self.instrumentation_scope,
|
||||
)
|
||||
return span
|
||||
|
||||
|
||||
def generate_span_id() -> int:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
|
||||
|
||||
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||
span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
|
||||
if start_time_a is None:
|
||||
return None
|
||||
timestamp_in_seconds = start_time_a.timestamp()
|
||||
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
|
||||
return timestamp_in_nanoseconds
|
||||
@ -0,0 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SpanData(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")
|
||||
@ -0,0 +1,64 @@
|
||||
from enum import Enum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
|
||||
GEN_AI_USER_ID = "gen_ai.user.id"
|
||||
|
||||
GEN_AI_USER_NAME = "gen_ai.user.name"
|
||||
|
||||
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
|
||||
|
||||
GEN_AI_FRAMEWORK = "gen_ai.framework"
|
||||
|
||||
|
||||
# Chain
|
||||
INPUT_VALUE = "input.value"
|
||||
|
||||
OUTPUT_VALUE = "output.value"
|
||||
|
||||
|
||||
# Retriever
|
||||
RETRIEVAL_QUERY = "retrieval.query"
|
||||
|
||||
RETRIEVAL_DOCUMENT = "retrieval.document"
|
||||
|
||||
|
||||
# LLM
|
||||
GEN_AI_MODEL_NAME = "gen_ai.model_name"
|
||||
|
||||
GEN_AI_SYSTEM = "gen_ai.system"
|
||||
|
||||
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
|
||||
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
|
||||
|
||||
GEN_AI_PROMPT = "gen_ai.prompt"
|
||||
|
||||
GEN_AI_COMPLETION = "gen_ai.completion"
|
||||
|
||||
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
|
||||
|
||||
# Tool
|
||||
TOOL_NAME = "tool.name"
|
||||
|
||||
TOOL_DESCRIPTION = "tool.description"
|
||||
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
CHAIN = "CHAIN"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
RERANKER = "RERANKER"
|
||||
LLM = "LLM"
|
||||
EMBEDDING = "EMBEDDING"
|
||||
TOOL = "TOOL"
|
||||
AGENT = "AGENT"
|
||||
TASK = "TASK"
|
||||
@ -0,0 +1,726 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
|
||||
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
|
||||
try:
|
||||
# Choose the appropriate exporter based on config type
|
||||
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
|
||||
if isinstance(arize_phoenix_config, ArizeConfig):
|
||||
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
|
||||
arize_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"space_id": arize_phoenix_config.space_id or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = GrpcOTLPSpanExporter(
|
||||
endpoint=arize_endpoint,
|
||||
headers=arize_headers,
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
|
||||
phoenix_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = HttpOTLPSpanExporter(
|
||||
endpoint=phoenix_endpoint,
|
||||
headers=phoenix_headers,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
attributes = {
|
||||
"openinference.project.name": arize_phoenix_config.project or "",
|
||||
"model_id": arize_phoenix_config.project or "",
|
||||
}
|
||||
resource = Resource(attributes=attributes)
|
||||
provider = trace_sdk.TracerProvider(resource=resource)
|
||||
processor = SimpleSpanProcessor(
|
||||
exporter,
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
# Create a named tracer instead of setting the global provider
|
||||
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
|
||||
logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
|
||||
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def datetime_to_nanos(dt: Optional[datetime]) -> int:
|
||||
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def uuid_to_trace_id(string: Optional[str]) -> int:
|
||||
"""Convert UUID string to a valid trace ID (16-byte integer)."""
|
||||
if string is None:
|
||||
string = ""
|
||||
hash_object = hashlib.sha256(string.encode())
|
||||
|
||||
# Take the first 16 bytes (128 bits) of the hash
|
||||
digest = hash_object.digest()[:16]
|
||||
|
||||
# Convert to integer (128 bits)
|
||||
return int.from_bytes(digest, byteorder="big")
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
workflow_metadata = {
|
||||
"workflow_id": trace_info.workflow_run_id or "",
|
||||
"message_id": trace_info.message_id or "",
|
||||
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
|
||||
"status": trace_info.workflow_run_status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
# Process workflow nodes
|
||||
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
|
||||
node_metadata = {
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": node_execution.tenant_id,
|
||||
"app_id": node_execution.app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
|
||||
}
|
||||
|
||||
if node_execution.execution_metadata:
|
||||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM.value
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_metadata["ls_provider"] = provider
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
if usage_data:
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL.value
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
)
|
||||
|
||||
try:
|
||||
if node_execution.node_type == "llm":
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
|
||||
if model:
|
||||
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
|
||||
)
|
||||
finally:
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
message_metadata = {
|
||||
"message_id": trace_info.message_id or "",
|
||||
"conversation_mode": str(trace_info.conversation_mode or ""),
|
||||
"user_id": trace_info.message_data.from_account_id or "",
|
||||
"file_list": json.dumps(file_list),
|
||||
"status": trace_info.message_data.status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
"prompt_tokens": trace_info.message_tokens or 0,
|
||||
"completion_tokens": trace_info.answer_tokens or 0,
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
message_metadata.update(trace_info.metadata)
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
message_metadata["end_user_id"] = end_user_data.session_id
|
||||
|
||||
attributes = {
|
||||
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
message_span_id = RandomIdGenerator().generate_span_id()
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=message_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
message_span = self.tracer.start_span(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
attributes=attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
message_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
|
||||
# Convert outputs to string based on type
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
llm_attributes = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: outputs_str,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
if isinstance(trace_info.inputs, list):
|
||||
for i, msg in enumerate(trace_info.inputs):
|
||||
if isinstance(msg, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
|
||||
"role", "user"
|
||||
)
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(trace_info.inputs, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(trace_info.inputs, str):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
|
||||
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
|
||||
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
|
||||
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
|
||||
|
||||
if trace_info.message_data.model_id is not None:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
|
||||
if trace_info.message_data.model_provider is not None:
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
|
||||
|
||||
if trace_info.message_data and trace_info.message_data.message_metadata:
|
||||
metadata_dict = json.loads(trace_info.message_data.message_metadata)
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
llm_span = self.tracer.start_span(
|
||||
name="llm",
|
||||
attributes=llm_attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
llm_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
finally:
|
||||
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "moderation",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(
|
||||
{
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "suggested_question",
|
||||
"status": trace_info.status,
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
"ls_provider": trace_info.model_provider or "",
|
||||
"ls_model_name": trace_info.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "dataset_retrieval",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
"start_time": start_time.isoformat() if start_time else "",
|
||||
"end_time": end_time.isoformat() if end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
tool_span_id = RandomIdGenerator().generate_span_id()
|
||||
logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
|
||||
|
||||
# Create span context with the same trace_id as the parent
|
||||
# todo: Create with the appropriate parent span context, so that the tool span is
|
||||
# a child of the appropriate span (e.g. message span)
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=tool_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
if isinstance(trace_info.tool_parameters, dict)
|
||||
else str(trace_info.tool_parameters)
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=trace_info.tool_name,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.TOOL_NAME: trace_info.tool_name,
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"project_name": self.project,
|
||||
"message_id": trace_info.message_id,
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
with self.tracer.start_span("api_check") as span:
|
||||
span.set_attribute("test", "true")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
|
||||
return "https://app.arize.com/"
|
||||
else:
|
||||
return f"{self.arize_phoenix_config.endpoint}/projects/"
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
return workflow_nodes
|
||||
@ -0,0 +1,280 @@
|
||||
import base64
|
||||
import binascii
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.encrypter import (
|
||||
batch_decrypt_token,
|
||||
decrypt_token,
|
||||
encrypt_token,
|
||||
get_decrypt_decoding,
|
||||
obfuscated_token,
|
||||
)
|
||||
from libs.rsa import PrivkeyNotFoundError
|
||||
|
||||
|
||||
class TestObfuscatedToken:
|
||||
@pytest.mark.parametrize(
|
||||
("token", "expected"),
|
||||
[
|
||||
("", ""), # Empty token
|
||||
("1234567", "*" * 20), # Short token (<8 chars)
|
||||
("12345678", "*" * 20), # Boundary case (8 chars)
|
||||
("123456789abcdef", "123456" + "*" * 12 + "ef"), # Long token
|
||||
("abc!@#$%^&*()def", "abc!@#" + "*" * 12 + "ef"), # Special chars
|
||||
],
|
||||
)
|
||||
def test_obfuscation_logic(self, token, expected):
|
||||
"""Test core obfuscation logic for various token lengths"""
|
||||
assert obfuscated_token(token) == expected
|
||||
|
||||
def test_sensitive_data_protection(self):
|
||||
"""Ensure obfuscation never reveals full sensitive data"""
|
||||
token = "api_key_secret_12345"
|
||||
obfuscated = obfuscated_token(token)
|
||||
assert token not in obfuscated
|
||||
assert "*" * 12 in obfuscated
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
|
||||
assert result == base64.b64encode(b"encrypted_data").decode()
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
|
||||
assert "Tenant with id invalid-tenant not found" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDecryptToken:
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_successful_decryption(self, mock_decrypt):
|
||||
"""Test successful token decryption"""
|
||||
mock_decrypt.return_value = "decrypted_token"
|
||||
encrypted_data = base64.b64encode(b"encrypted_data").decode()
|
||||
|
||||
result = decrypt_token("tenant-123", encrypted_data)
|
||||
|
||||
assert result == "decrypted_token"
|
||||
mock_decrypt.assert_called_once_with(b"encrypted_data", "tenant-123")
|
||||
|
||||
def test_invalid_base64(self):
|
||||
"""Test handling of invalid base64 input"""
|
||||
with pytest.raises(binascii.Error):
|
||||
decrypt_token("tenant-123", "invalid_base64!!!")
|
||||
|
||||
|
||||
class TestBatchDecryptToken:
|
||||
@patch("libs.rsa.get_decrypt_decoding")
|
||||
@patch("libs.rsa.decrypt_token_with_decoding")
|
||||
def test_batch_decryption(self, mock_decrypt_with_decoding, mock_get_decoding):
|
||||
"""Test batch decryption functionality"""
|
||||
mock_rsa_key = MagicMock()
|
||||
mock_cipher_rsa = MagicMock()
|
||||
mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
|
||||
|
||||
# Test multiple tokens
|
||||
mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3"]
|
||||
tokens = [
|
||||
base64.b64encode(b"encrypted1").decode(),
|
||||
base64.b64encode(b"encrypted2").decode(),
|
||||
base64.b64encode(b"encrypted3").decode(),
|
||||
]
|
||||
result = batch_decrypt_token("tenant-123", tokens)
|
||||
|
||||
assert result == ["token1", "token2", "token3"]
|
||||
# Key should only be loaded once
|
||||
mock_get_decoding.assert_called_once_with("tenant-123")
|
||||
|
||||
|
||||
class TestGetDecryptDecoding:
|
||||
@patch("extensions.ext_redis.redis_client.get")
|
||||
@patch("extensions.ext_storage.storage.load")
|
||||
def test_private_key_not_found(self, mock_storage_load, mock_redis_get):
|
||||
"""Test error when private key file doesn't exist"""
|
||||
mock_redis_get.return_value = None
|
||||
mock_storage_load.side_effect = FileNotFoundError()
|
||||
|
||||
with pytest.raises(PrivkeyNotFoundError) as exc_info:
|
||||
get_decrypt_decoding("tenant-123")
|
||||
|
||||
assert "Private key not found, tenant_id: tenant-123" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
mock_decrypt.return_value = original_token
|
||||
|
||||
# Test encryption
|
||||
encrypted = encrypt_token("tenant-123", original_token)
|
||||
|
||||
# Test decryption
|
||||
decrypted = decrypt_token("tenant-123", encrypted)
|
||||
|
||||
assert decrypted == original_token
|
||||
|
||||
|
||||
class TestSecurity:
|
||||
"""Critical security tests for encryption system"""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
encrypted = encrypt_token("tenant-123", "sensitive_data")
|
||||
|
||||
# Attempt to decrypt with different tenant should fail
|
||||
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||
mock_decrypt.side_effect = Exception("Invalid tenant key")
|
||||
|
||||
with pytest.raises(Exception, match="Invalid tenant key"):
|
||||
decrypt_token("different-tenant", encrypted)
|
||||
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_tampered_ciphertext_rejection(self, mock_decrypt):
|
||||
"""Detect and reject tampered ciphertext"""
|
||||
valid_encrypted = base64.b64encode(b"valid_data").decode()
|
||||
|
||||
# Tamper with ciphertext
|
||||
tampered_bytes = bytearray(base64.b64decode(valid_encrypted))
|
||||
tampered_bytes[0] ^= 0xFF
|
||||
tampered = base64.b64encode(bytes(tampered_bytes)).decode()
|
||||
|
||||
mock_decrypt.side_effect = Exception("Decryption error")
|
||||
|
||||
with pytest.raises(Exception, match="Decryption error"):
|
||||
decrypt_token("tenant-123", tampered)
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
|
||||
results = [encrypt_token("tenant-123", "token") for _ in range(3)]
|
||||
|
||||
# All results should be different
|
||||
assert len(set(results)) == 3
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Additional security-focused edge case tests"""
|
||||
|
||||
def test_should_handle_empty_string_in_obfuscation(self):
|
||||
"""Test handling of empty string in obfuscation"""
|
||||
# Test empty string (which is a valid str type)
|
||||
assert obfuscated_token("") == ""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
|
||||
assert result == base64.b64encode(b"encrypted_empty").decode()
|
||||
mock_encrypt.assert_called_with("", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
special_tokens = [
|
||||
"token\x00with\x00null", # Null bytes
|
||||
"token_with_emoji_😀🎉", # Unicode emoji
|
||||
"token\nwith\nnewlines", # Newlines
|
||||
"token\twith\ttabs", # Tabs
|
||||
"token_with_中文字符", # Chinese characters
|
||||
]
|
||||
|
||||
for token in special_tokens:
|
||||
result = encrypt_token("tenant-123", token)
|
||||
assert result == base64.b64encode(b"encrypted_special").decode()
|
||||
mock_encrypt.assert_called_with(token, "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
mock_encrypt.side_effect = ValueError("Message too long for RSA key size")
|
||||
|
||||
# Create a token that would exceed RSA limits
|
||||
long_token = "x" * 300
|
||||
|
||||
with pytest.raises(ValueError, match="Message too long for RSA key size"):
|
||||
encrypt_token("tenant-123", long_token)
|
||||
|
||||
@patch("libs.rsa.get_decrypt_decoding")
|
||||
@patch("libs.rsa.decrypt_token_with_decoding")
|
||||
def test_batch_decrypt_loads_key_only_once(self, mock_decrypt_with_decoding, mock_get_decoding):
|
||||
"""Verify batch decryption optimization - loads key only once"""
|
||||
mock_rsa_key = MagicMock()
|
||||
mock_cipher_rsa = MagicMock()
|
||||
mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
|
||||
|
||||
# Test with multiple tokens
|
||||
mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3", "token4", "token5"]
|
||||
tokens = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(5)]
|
||||
|
||||
result = batch_decrypt_token("tenant-123", tokens)
|
||||
|
||||
assert result == ["token1", "token2", "token3", "token4", "token5"]
|
||||
# Key should only be loaded once regardless of token count
|
||||
mock_get_decoding.assert_called_once_with("tenant-123")
|
||||
assert mock_decrypt_with_decoding.call_count == 5
|
||||
@ -0,0 +1,194 @@
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.url_signer import SignedUrlParams, UrlSigner
|
||||
|
||||
|
||||
class TestUrlSigner:
|
||||
"""Test cases for UrlSigner class"""
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_generate_signed_url_params(self):
|
||||
"""Test generation of signed URL parameters with all required fields"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
# Verify the returned object and required fields
|
||||
assert isinstance(params, SignedUrlParams)
|
||||
assert params.sign_key == sign_key
|
||||
assert params.timestamp is not None
|
||||
assert params.nonce is not None
|
||||
assert params.sign is not None
|
||||
|
||||
# Verify nonce format (32 character hex string)
|
||||
assert len(params.nonce) == 32
|
||||
assert all(c in "0123456789abcdef" for c in params.nonce)
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_generate_complete_signed_url(self):
|
||||
"""Test generation of complete signed URL with query parameters"""
|
||||
base_url = "https://example.com/api/test"
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
signed_url = UrlSigner.get_signed_url(base_url, sign_key, prefix)
|
||||
|
||||
# Parse URL and verify structure
|
||||
parsed = urlparse(signed_url)
|
||||
assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == base_url
|
||||
|
||||
# Verify query parameters
|
||||
query_params = parse_qs(parsed.query)
|
||||
assert "timestamp" in query_params
|
||||
assert "nonce" in query_params
|
||||
assert "sign" in query_params
|
||||
|
||||
# Verify each parameter has exactly one value
|
||||
assert len(query_params["timestamp"]) == 1
|
||||
assert len(query_params["nonce"]) == 1
|
||||
assert len(query_params["sign"]) == 1
|
||||
|
||||
# Verify parameter values are not empty
|
||||
assert query_params["timestamp"][0]
|
||||
assert query_params["nonce"][0]
|
||||
assert query_params["sign"][0]
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_verify_valid_signature(self):
|
||||
"""Test verification of valid signature"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate and verify signature
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
is_valid = UrlSigner.verify(
|
||||
sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=prefix
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
@pytest.mark.parametrize(
|
||||
("field", "modifier"),
|
||||
[
|
||||
("sign_key", lambda _: "wrong-sign-key"),
|
||||
("timestamp", lambda t: str(int(t) + 1000)),
|
||||
("nonce", lambda _: "different-nonce-123456789012345"),
|
||||
("prefix", lambda _: "wrong-prefix"),
|
||||
("sign", lambda s: s + "tampered"),
|
||||
],
|
||||
)
|
||||
def test_should_reject_invalid_signature_params(self, field, modifier):
|
||||
"""Test signature verification rejects invalid parameters"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate valid signed parameters
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
# Prepare verification parameters
|
||||
verify_params = {
|
||||
"sign_key": sign_key,
|
||||
"timestamp": params.timestamp,
|
||||
"nonce": params.nonce,
|
||||
"sign": params.sign,
|
||||
"prefix": prefix,
|
||||
}
|
||||
|
||||
# Modify the specific field
|
||||
verify_params[field] = modifier(verify_params[field])
|
||||
|
||||
# Verify should fail
|
||||
is_valid = UrlSigner.verify(**verify_params)
|
||||
assert is_valid is False
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", None)
|
||||
def test_should_raise_error_without_secret_key(self):
|
||||
"""Test that signing fails when SECRET_KEY is not configured"""
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
UrlSigner.get_signed_url_params("key", "prefix")
|
||||
|
||||
assert "SECRET_KEY is not set" in str(exc_info.value)
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_generate_unique_signatures(self):
|
||||
"""Test that different inputs produce different signatures"""
|
||||
params1 = UrlSigner.get_signed_url_params("key1", "prefix1")
|
||||
params2 = UrlSigner.get_signed_url_params("key2", "prefix2")
|
||||
|
||||
# Different inputs should produce different signatures
|
||||
assert params1.sign != params2.sign
|
||||
assert params1.nonce != params2.nonce
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_handle_special_characters(self):
|
||||
"""Test handling of special characters in parameters"""
|
||||
special_cases = [
|
||||
"test with spaces",
|
||||
"test/with/slashes",
|
||||
"test中文字符",
|
||||
]
|
||||
|
||||
for sign_key in special_cases:
|
||||
params = UrlSigner.get_signed_url_params(sign_key, "prefix")
|
||||
|
||||
# Should generate valid signature and verify correctly
|
||||
is_valid = UrlSigner.verify(
|
||||
sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix="prefix"
|
||||
)
|
||||
assert is_valid is True
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_ensure_nonce_randomness(self):
|
||||
"""Test that nonce is random for each generation - critical for security"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate multiple nonces
|
||||
nonces = set()
|
||||
for _ in range(5):
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
nonces.add(params.nonce)
|
||||
|
||||
# All nonces should be unique
|
||||
assert len(nonces) == 5
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
@patch("time.time", return_value=1234567890)
|
||||
@patch("os.urandom", return_value=b"\xab\xcd\xef\x12\x34\x56\x78\x90\xab\xcd\xef\x12\x34\x56\x78\x90")
|
||||
def test_should_produce_consistent_signatures(self, mock_urandom, mock_time):
|
||||
"""Test that same inputs produce same signature - ensures deterministic behavior"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate signature multiple times with same inputs (time and nonce are mocked)
|
||||
params1 = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
params2 = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
# With mocked time and random, should produce identical results
|
||||
assert params1.timestamp == params2.timestamp
|
||||
assert params1.nonce == params2.nonce
|
||||
assert params1.sign == params2.sign
|
||||
|
||||
# Verify the signature is valid
|
||||
assert UrlSigner.verify(
|
||||
sign_key=sign_key, timestamp=params1.timestamp, nonce=params1.nonce, sign=params1.sign, prefix=prefix
|
||||
)
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_handle_empty_strings(self):
|
||||
"""Test handling of empty string parameters - common edge case"""
|
||||
# Empty sign_key and prefix should still work
|
||||
params = UrlSigner.get_signed_url_params("", "")
|
||||
assert params.sign is not None
|
||||
|
||||
# Should verify correctly
|
||||
is_valid = UrlSigner.verify(
|
||||
sign_key="", timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=""
|
||||
)
|
||||
assert is_valid is True
|
||||
@ -0,0 +1 @@
|
||||
# Unit tests for core ops module
|
||||
@ -0,0 +1,385 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.entities.config_entity import (
|
||||
AliyunConfig,
|
||||
ArizeConfig,
|
||||
LangfuseConfig,
|
||||
LangSmithConfig,
|
||||
OpikConfig,
|
||||
PhoenixConfig,
|
||||
TracingProviderEnum,
|
||||
WeaveConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestTracingProviderEnum:
|
||||
"""Test cases for TracingProviderEnum"""
|
||||
|
||||
def test_enum_values(self):
|
||||
"""Test that all expected enum values are present"""
|
||||
assert TracingProviderEnum.ARIZE == "arize"
|
||||
assert TracingProviderEnum.PHOENIX == "phoenix"
|
||||
assert TracingProviderEnum.LANGFUSE == "langfuse"
|
||||
assert TracingProviderEnum.LANGSMITH == "langsmith"
|
||||
assert TracingProviderEnum.OPIK == "opik"
|
||||
assert TracingProviderEnum.WEAVE == "weave"
|
||||
assert TracingProviderEnum.ALIYUN == "aliyun"
|
||||
|
||||
|
||||
class TestArizeConfig:
|
||||
"""Test cases for ArizeConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Arize configuration"""
|
||||
config = ArizeConfig(
|
||||
api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.space_id == "test_space"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = ArizeConfig()
|
||||
assert config.api_key is None
|
||||
assert config.space_id is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = ArizeConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_project_validation_none(self):
|
||||
"""Test project validation with None value"""
|
||||
config = ArizeConfig(project=None)
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = ArizeConfig(endpoint="")
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="ftp://invalid.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="invalid.com")
|
||||
|
||||
|
||||
class TestPhoenixConfig:
|
||||
"""Test cases for PhoenixConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Phoenix configuration"""
|
||||
config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = PhoenixConfig()
|
||||
assert config.api_key is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://app.phoenix.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = PhoenixConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1")
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
|
||||
|
||||
class TestLangfuseConfig:
|
||||
"""Test cases for LangfuseConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Langfuse configuration"""
|
||||
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
|
||||
assert config.public_key == "public_key"
|
||||
assert config.secret_key == "secret_key"
|
||||
assert config.host == "https://custom.langfuse.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(public_key="public")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(secret_key="secret")
|
||||
|
||||
def test_host_validation_empty(self):
|
||||
"""Test host validation with empty value"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret", host="")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
|
||||
|
||||
class TestLangSmithConfig:
|
||||
"""Test cases for LangSmithConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid LangSmith configuration"""
|
||||
config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.smith.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangSmithConfig(api_key="key", project="project")
|
||||
assert config.endpoint == "https://api.smith.langchain.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
|
||||
|
||||
|
||||
class TestOpikConfig:
|
||||
"""Test cases for OpikConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Opik configuration"""
|
||||
config = OpikConfig(
|
||||
api_key="test_key",
|
||||
project="test_project",
|
||||
workspace="test_workspace",
|
||||
url="https://custom.comet.com/opik/api/",
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.workspace == "test_workspace"
|
||||
assert config.url == "https://custom.comet.com/opik/api/"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = OpikConfig()
|
||||
assert config.api_key is None
|
||||
assert config.project is None
|
||||
assert config.workspace is None
|
||||
assert config.url == "https://www.comet.com/opik/api/"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = OpikConfig(project="")
|
||||
assert config.project == "Default Project"
|
||||
|
||||
def test_url_validation_empty(self):
|
||||
"""Test URL validation with empty value"""
|
||||
config = OpikConfig(url="")
|
||||
assert config.url == "https://www.comet.com/opik/api/"
|
||||
|
||||
def test_url_validation_missing_suffix(self):
|
||||
"""Test URL validation requires /api/ suffix"""
|
||||
with pytest.raises(ValidationError, match="URL should end with /api/"):
|
||||
OpikConfig(url="https://custom.comet.com/opik/")
|
||||
|
||||
def test_url_validation_invalid_scheme(self):
|
||||
"""Test URL validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
|
||||
OpikConfig(url="ftp://custom.comet.com/opik/api/")
|
||||
|
||||
|
||||
class TestWeaveConfig:
|
||||
"""Test cases for WeaveConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Weave configuration"""
|
||||
config = WeaveConfig(
|
||||
api_key="test_key",
|
||||
entity="test_entity",
|
||||
project="test_project",
|
||||
endpoint="https://custom.wandb.ai",
|
||||
host="https://custom.host.com",
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.entity == "test_entity"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.wandb.ai"
|
||||
assert config.host == "https://custom.host.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = WeaveConfig(api_key="key", project="project")
|
||||
assert config.entity is None
|
||||
assert config.endpoint == "https://trace.wandb.ai"
|
||||
assert config.host is None
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
|
||||
|
||||
def test_host_validation_optional(self):
|
||||
"""Test host validation is optional but validates when provided"""
|
||||
config = WeaveConfig(api_key="key", project="project", host=None)
|
||||
assert config.host is None
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="")
|
||||
assert config.host == ""
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
|
||||
assert config.host == "https://valid.host.com"
|
||||
|
||||
def test_host_validation_invalid_scheme(self):
|
||||
"""Test host validation rejects invalid schemes when provided"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
|
||||
|
||||
|
||||
class TestAliyunConfig:
|
||||
"""Test cases for AliyunConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Aliyun configuration"""
|
||||
config = AliyunConfig(
|
||||
app_name="test_app",
|
||||
license_key="test_license_key",
|
||||
endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
|
||||
)
|
||||
assert config.app_name == "test_app"
|
||||
assert config.license_key == "test_license_key"
|
||||
assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="test_license")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_app_name_validation_empty(self):
|
||||
"""Test app_name validation with empty value"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
|
||||
)
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="")
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_license_key_required(self):
|
||||
"""Test that license_key is required and cannot be empty"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
"""Integration tests for configuration classes"""
|
||||
|
||||
def test_all_configs_can_be_instantiated(self):
|
||||
"""Test that all config classes can be instantiated with valid data"""
|
||||
configs = [
|
||||
ArizeConfig(api_key="key"),
|
||||
PhoenixConfig(api_key="key"),
|
||||
LangfuseConfig(public_key="public", secret_key="secret"),
|
||||
LangSmithConfig(api_key="key", project="project"),
|
||||
OpikConfig(api_key="key"),
|
||||
WeaveConfig(api_key="key", project="project"),
|
||||
AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
assert config is not None
|
||||
|
||||
def test_url_normalization_consistency(self):
|
||||
"""Test that URL normalization works consistently across configs"""
|
||||
# Test that paths are removed from endpoints
|
||||
arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
|
||||
phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
|
||||
aliyun_config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
|
||||
assert arize_config.endpoint == "https://arize.com"
|
||||
assert phoenix_config.endpoint == "https://phoenix.com"
|
||||
assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_project_default_values(self):
|
||||
"""Test that project default values are set correctly"""
|
||||
arize_config = ArizeConfig(project="")
|
||||
phoenix_config = PhoenixConfig(project="")
|
||||
opik_config = OpikConfig(project="")
|
||||
aliyun_config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
|
||||
)
|
||||
|
||||
assert arize_config.project == "default"
|
||||
assert phoenix_config.project == "default"
|
||||
assert opik_config.project == "Default Project"
|
||||
assert aliyun_config.app_name == "dify_app"
|
||||
@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
"""Test cases for validate_url function"""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
"""Test valid HTTPS URL"""
|
||||
result = validate_url("https://example.com", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_valid_http_url(self):
|
||||
"""Test valid HTTP URL"""
|
||||
result = validate_url("http://example.com", "https://default.com")
|
||||
assert result == "http://example.com"
|
||||
|
||||
def test_url_with_path_removed(self):
|
||||
"""Test that URL path is removed during normalization"""
|
||||
result = validate_url("https://example.com/api/v1/test", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_with_query_removed(self):
|
||||
"""Test that URL query parameters are removed"""
|
||||
result = validate_url("https://example.com?param=value", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_with_fragment_removed(self):
|
||||
"""Test that URL fragments are removed"""
|
||||
result = validate_url("https://example.com#section", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_empty_url_returns_default(self):
|
||||
"""Test empty URL returns default"""
|
||||
result = validate_url("", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_none_url_returns_default(self):
|
||||
"""Test None URL returns default"""
|
||||
result = validate_url(None, "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_whitespace_url_returns_default(self):
|
||||
"""Test whitespace URL returns default"""
|
||||
result = validate_url(" ", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_invalid_scheme_raises_error(self):
|
||||
"""Test invalid scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("ftp://example.com", "https://default.com")
|
||||
|
||||
def test_no_scheme_raises_error(self):
|
||||
"""Test URL without scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("example.com", "https://default.com")
|
||||
|
||||
def test_custom_allowed_schemes(self):
|
||||
"""Test custom allowed schemes"""
|
||||
result = validate_url("https://example.com", "https://default.com", allowed_schemes=("https",))
|
||||
assert result == "https://example.com"
|
||||
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("http://example.com", "https://default.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class TestValidateUrlWithPath:
|
||||
"""Test cases for validate_url_with_path function"""
|
||||
|
||||
def test_valid_url_with_path(self):
|
||||
"""Test valid URL with path"""
|
||||
result = validate_url_with_path("https://example.com/api/v1", "https://default.com")
|
||||
assert result == "https://example.com/api/v1"
|
||||
|
||||
def test_valid_url_with_required_suffix(self):
|
||||
"""Test valid URL with required suffix"""
|
||||
result = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
|
||||
assert result == "https://example.com/api/"
|
||||
|
||||
def test_url_without_required_suffix_raises_error(self):
|
||||
"""Test URL without required suffix raises error"""
|
||||
with pytest.raises(ValueError, match="URL should end with /api/"):
|
||||
validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")
|
||||
|
||||
def test_empty_url_returns_default(self):
|
||||
"""Test empty URL returns default"""
|
||||
result = validate_url_with_path("", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_none_url_returns_default(self):
|
||||
"""Test None URL returns default"""
|
||||
result = validate_url_with_path(None, "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_invalid_scheme_raises_error(self):
|
||||
"""Test invalid scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
|
||||
validate_url_with_path("ftp://example.com", "https://default.com")
|
||||
|
||||
def test_no_scheme_raises_error(self):
|
||||
"""Test URL without scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
|
||||
validate_url_with_path("example.com", "https://default.com")
|
||||
|
||||
|
||||
class TestValidateProjectName:
|
||||
"""Test cases for validate_project_name function"""
|
||||
|
||||
def test_valid_project_name(self):
|
||||
"""Test valid project name"""
|
||||
result = validate_project_name("my-project", "default")
|
||||
assert result == "my-project"
|
||||
|
||||
def test_empty_project_name_returns_default(self):
|
||||
"""Test empty project name returns default"""
|
||||
result = validate_project_name("", "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_none_project_name_returns_default(self):
|
||||
"""Test None project name returns default"""
|
||||
result = validate_project_name(None, "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_whitespace_project_name_returns_default(self):
|
||||
"""Test whitespace project name returns default"""
|
||||
result = validate_project_name(" ", "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_project_name_with_whitespace_trimmed(self):
|
||||
"""Test project name with whitespace is trimmed"""
|
||||
result = validate_project_name(" my-project ", "default")
|
||||
assert result == "my-project"
|
||||
|
||||
def test_custom_default_name(self):
|
||||
"""Test custom default name"""
|
||||
result = validate_project_name("", "Custom Default")
|
||||
assert result == "Custom Default"
|
||||
@ -0,0 +1,53 @@
|
||||
from redis import RedisError
|
||||
|
||||
from extensions.ext_redis import redis_fallback
|
||||
|
||||
|
||||
def test_redis_fallback_success():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
assert test_func() == "success"
|
||||
|
||||
|
||||
def test_redis_fallback_error():
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() == "fallback"
|
||||
|
||||
|
||||
def test_redis_fallback_none_default():
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() is None
|
||||
|
||||
|
||||
def test_redis_fallback_with_args():
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(1, 2) == 0
|
||||
|
||||
|
||||
def test_redis_fallback_with_kwargs():
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
|
||||
from libs.helper import extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
class TestExtractTenantId:
|
||||
"""Test cases for the extract_tenant_id utility function."""
|
||||
|
||||
def test_extract_tenant_id_from_account_with_tenant(self):
|
||||
"""Test extracting tenant_id from Account with current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
# Mock the current_tenant_id property
|
||||
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id == "account-tenant-123"
|
||||
|
||||
def test_extract_tenant_id_from_account_without_tenant(self):
|
||||
"""Test extracting tenant_id from Account without current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
account._current_tenant = None
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_from_enduser_with_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser with tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = "enduser-tenant-456"
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id == "enduser-tenant-456"
|
||||
|
||||
def test_extract_tenant_id_from_enduser_without_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser without tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = None
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_with_invalid_user_type(self):
|
||||
"""Test extracting tenant_id with invalid user type raises ValueError."""
|
||||
invalid_user = "not_a_user_object"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(invalid_user)
|
||||
|
||||
def test_extract_tenant_id_with_none_user(self):
|
||||
"""Test extracting tenant_id with None user raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(None)
|
||||
|
||||
def test_extract_tenant_id_with_dict_user(self):
|
||||
"""Test extracting tenant_id with dict user raises ValueError."""
|
||||
dict_user = {"id": "123", "tenant_id": "456"}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(dict_user)
|
||||
@ -0,0 +1,74 @@
|
||||
import base64
|
||||
import binascii
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
|
||||
|
||||
class TestValidPassword:
|
||||
"""Test password format validation"""
|
||||
|
||||
def test_should_accept_valid_passwords(self):
|
||||
"""Test accepting valid password formats"""
|
||||
assert valid_password("password123") == "password123"
|
||||
assert valid_password("test1234") == "test1234"
|
||||
assert valid_password("Test123456") == "Test123456"
|
||||
|
||||
def test_should_reject_invalid_passwords(self):
|
||||
"""Test rejecting invalid password formats"""
|
||||
# Too short
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
valid_password("abc123")
|
||||
assert "Password must contain letters and numbers" in str(exc_info.value)
|
||||
|
||||
# No numbers
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("abcdefgh")
|
||||
|
||||
# No letters
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("12345678")
|
||||
|
||||
# Empty
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("")
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Test password hashing and comparison"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test data"""
|
||||
self.password = "test123password"
|
||||
self.salt = os.urandom(16)
|
||||
self.salt_base64 = base64.b64encode(self.salt).decode()
|
||||
|
||||
password_hash = hash_password(self.password, self.salt)
|
||||
self.password_hash_base64 = base64.b64encode(password_hash).decode()
|
||||
|
||||
def test_should_verify_correct_password(self):
|
||||
"""Test correct password verification"""
|
||||
result = compare_password(self.password, self.password_hash_base64, self.salt_base64)
|
||||
assert result is True
|
||||
|
||||
def test_should_reject_wrong_password(self):
|
||||
"""Test rejection of incorrect passwords"""
|
||||
result = compare_password("wrongpassword", self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
|
||||
def test_should_handle_invalid_base64(self):
|
||||
"""Test handling of invalid base64 data"""
|
||||
# Invalid base64 hash
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, "invalid_base64!", self.salt_base64)
|
||||
|
||||
# Invalid base64 salt
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, self.password_hash_base64, "invalid_base64!")
|
||||
|
||||
def test_should_be_case_sensitive(self):
|
||||
"""Test password case sensitivity"""
|
||||
result = compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
@ -0,0 +1,465 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
|
||||
"""Create a mock LLMUsage with all required fields"""
|
||||
return LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1"),
|
||||
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1"),
|
||||
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
currency="USD",
|
||||
latency=1.5,
|
||||
)
|
||||
|
||||
|
||||
def get_model_entity(provider: str, model_name: str, support_structure_output: bool = False) -> AIModelEntity:
|
||||
"""Create a mock AIModelEntity for testing"""
|
||||
model_schema = MagicMock()
|
||||
model_schema.model = model_name
|
||||
model_schema.provider = provider
|
||||
model_schema.model_type = ModelType.LLM
|
||||
model_schema.model_provider = provider
|
||||
model_schema.model_name = model_name
|
||||
model_schema.support_structure_output = support_structure_output
|
||||
model_schema.parameter_rules = []
|
||||
|
||||
return model_schema
|
||||
|
||||
|
||||
def get_model_instance() -> MagicMock:
|
||||
"""Create a mock ModelInstance for testing"""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.provider = "openai"
|
||||
mock_instance.credentials = {}
|
||||
return mock_instance
|
||||
|
||||
|
||||
def test_structured_output_parser():
|
||||
"""Test cases for invoke_llm_with_structured_output function"""
|
||||
|
||||
testcases = [
|
||||
# Test case 1: Model with native structured output support, non-streaming
|
||||
{
|
||||
"name": "native_structured_output_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 2: Model with native structured output support, streaming
|
||||
{
|
||||
"name": "native_structured_output_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 3: Model without native structured output support, non-streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_non_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"answer": "test response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 4: Model without native structured output support, streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"answer": "test'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 5: Streaming with list content
|
||||
{
|
||||
"name": "streaming_with_list_content",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"data":'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=' "value"}'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
||||
{
|
||||
"name": "error_non_string_content_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content=None), # Non-string content
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": None,
|
||||
"should_raise": True,
|
||||
"expected_error": OutputParserError,
|
||||
},
|
||||
# Test case 7: JSON repair scenario
|
||||
{
|
||||
"name": "json_repair_scenario",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"'), # Invalid JSON - missing closing brace
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 8: Model with parameter rules for response format
|
||||
{
|
||||
"name": "model_with_parameter_rules",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["json_schema"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"result": "success"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 9: Model without native support but with JSON response format rules
|
||||
{
|
||||
"name": "non_native_with_json_rules",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["JSON"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"output": "result"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
]
|
||||
|
||||
for case in testcases:
|
||||
print(f"Running test case: {case['name']}")
|
||||
|
||||
# Setup model entity
|
||||
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
|
||||
|
||||
# Add parameter rules if specified
|
||||
if "parameter_rules" in case:
|
||||
model_schema.parameter_rules = case["parameter_rules"]
|
||||
|
||||
# Setup model instance
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = case["expected_llm_response"]
|
||||
|
||||
# Setup prompt messages
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Generate a response according to the schema."),
|
||||
]
|
||||
|
||||
if case["should_raise"]:
|
||||
# Test error cases
|
||||
with pytest.raises(case["expected_error"]): # noqa: PT012
|
||||
if case["stream"]:
|
||||
result_generator = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
# Consume the generator to trigger the error
|
||||
list(result_generator)
|
||||
else:
|
||||
invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
else:
|
||||
# Test successful cases
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Configure json_repair mock for cases that need it
|
||||
if case["name"] == "json_repair_scenario":
|
||||
mock_json_repair.return_value = {"name": "test"}
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
||||
user="test_user",
|
||||
)
|
||||
|
||||
if case["expected_result_type"] == "generator":
|
||||
# Test streaming results
|
||||
assert hasattr(result, "__iter__")
|
||||
chunks = list(result)
|
||||
assert len(chunks) > 0
|
||||
|
||||
# Verify all chunks are LLMResultChunkWithStructuredOutput
|
||||
for chunk in chunks[:-1]: # All except last
|
||||
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert chunk.model == case["model_name"]
|
||||
|
||||
# Last chunk should have structured output
|
||||
last_chunk = chunks[-1]
|
||||
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert last_chunk.structured_output is not None
|
||||
assert isinstance(last_chunk.structured_output, dict)
|
||||
else:
|
||||
# Test non-streaming results
|
||||
assert isinstance(result, case["expected_result_type"])
|
||||
assert result.model == case["model_name"]
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
|
||||
# Verify model_instance.invoke_llm was called with correct parameters
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
assert call_args.kwargs["stream"] == case["stream"]
|
||||
assert call_args.kwargs["user"] == "test_user"
|
||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||
|
||||
|
||||
def test_parse_structured_output_edge_cases():
|
||||
"""Test edge cases for structured output parsing"""
|
||||
|
||||
# Test case with list that contains dict (reasoning model scenario)
|
||||
testcase_list_with_dict = {
|
||||
"name": "list_with_dict_parsing",
|
||||
"provider": "deepseek",
|
||||
"model_name": "deepseek-r1",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="deepseek-r1",
|
||||
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
}
|
||||
|
||||
# Setup for list parsing test
|
||||
model_schema = get_model_entity(
|
||||
testcase_list_with_dict["provider"],
|
||||
testcase_list_with_dict["model_name"],
|
||||
testcase_list_with_dict["support_structure_output"],
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
||||
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Mock json_repair to return a list with dict
|
||||
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=testcase_list_with_dict["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=testcase_list_with_dict["json_schema"],
|
||||
stream=testcase_list_with_dict["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
assert result.structured_output == {"thought": "reasoning process"}
|
||||
|
||||
|
||||
def test_model_specific_schema_preparation():
|
||||
"""Test schema preparation for different model types"""
|
||||
|
||||
# Test Gemini model
|
||||
gemini_case = {
|
||||
"provider": "google",
|
||||
"model_name": "gemini-pro",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
|
||||
}
|
||||
|
||||
model_schema = get_model_entity(
|
||||
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
model="gemini-pro",
|
||||
message=AssistantPromptMessage(content='{"result": "true"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test")]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=gemini_case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=gemini_case["json_schema"],
|
||||
stream=gemini_case["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
|
||||
# Verify model_instance.invoke_llm was called and check the schema preparation
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
# For Gemini, the schema should not have additionalProperties and boolean should be converted to string
|
||||
assert "json_schema" in call_args.kwargs["model_parameters"]
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,11 @@
|
||||
import { TracingProvider } from './type'
|
||||
|
||||
export const docURL = {
|
||||
[TracingProvider.arize]: 'https://docs.arize.com/arize',
|
||||
[TracingProvider.phoenix]: 'https://docs.arize.com/phoenix',
|
||||
[TracingProvider.langSmith]: 'https://docs.smith.langchain.com/',
|
||||
[TracingProvider.langfuse]: 'https://docs.langfuse.com',
|
||||
[TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions',
|
||||
[TracingProvider.weave]: 'https://weave-docs.wandb.ai/',
|
||||
[TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680',
|
||||
}
|
||||
|
||||
@ -0,0 +1,145 @@
|
||||
import type { ReactElement } from 'react'
|
||||
import { cloneElement, useCallback } from 'react'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../base/portal-to-follow-elem'
|
||||
import { RiMoreLine } from '@remixicon/react'
|
||||
|
||||
export type Operation = {
|
||||
id: string; title: string; icon: ReactElement; onClick: () => void
|
||||
}
|
||||
|
||||
const AppOperations = ({ operations, gap }: {
|
||||
operations: Operation[]
|
||||
gap: number
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [visibleOpreations, setVisibleOperations] = useState<Operation[]>([])
|
||||
const [moreOperations, setMoreOperations] = useState<Operation[]>([])
|
||||
const [showMore, setShowMore] = useState(false)
|
||||
const navRef = useRef<HTMLDivElement>(null)
|
||||
const handleTriggerMore = useCallback(() => {
|
||||
setShowMore(true)
|
||||
}, [setShowMore])
|
||||
|
||||
useEffect(() => {
|
||||
const moreElement = document.getElementById('more')
|
||||
const navElement = document.getElementById('nav')
|
||||
let width = 0
|
||||
const containerWidth = navElement?.clientWidth ?? 0
|
||||
const moreWidth = moreElement?.clientWidth ?? 0
|
||||
|
||||
if (containerWidth === 0 || moreWidth === 0) return
|
||||
|
||||
const updatedEntries: Record<string, boolean> = operations.reduce((pre, cur) => {
|
||||
pre[cur.id] = false
|
||||
return pre
|
||||
}, {} as Record<string, boolean>)
|
||||
const childrens = Array.from(navRef.current!.children).slice(0, -1)
|
||||
for (let i = 0; i < childrens.length; i++) {
|
||||
const child: any = childrens[i]
|
||||
const id = child.dataset.targetid
|
||||
if (!id) break
|
||||
const childWidth = child.clientWidth
|
||||
|
||||
if (width + gap + childWidth + moreWidth <= containerWidth) {
|
||||
updatedEntries[id] = true
|
||||
width += gap + childWidth
|
||||
}
|
||||
else {
|
||||
if (i === childrens.length - 1 && width + childWidth <= containerWidth)
|
||||
updatedEntries[id] = true
|
||||
else
|
||||
updatedEntries[id] = false
|
||||
break
|
||||
}
|
||||
}
|
||||
setVisibleOperations(operations.filter(item => updatedEntries[item.id]))
|
||||
setMoreOperations(operations.filter(item => !updatedEntries[item.id]))
|
||||
}, [operations, gap])
|
||||
|
||||
return (
|
||||
<>
|
||||
{!visibleOpreations.length && <div
|
||||
id="nav"
|
||||
ref={navRef}
|
||||
className="flex h-0 items-center self-stretch overflow-hidden"
|
||||
style={{ gap }}
|
||||
>
|
||||
{operations.map((operation, index) =>
|
||||
<Button
|
||||
key={index}
|
||||
data-targetid={operation.id}
|
||||
size={'small'}
|
||||
variant={'secondary'}
|
||||
className="gap-[1px]">
|
||||
{cloneElement(operation.icon, { className: 'h-3.5 w-3.5 text-components-button-secondary-text' })}
|
||||
<span className="system-xs-medium text-components-button-secondary-text">
|
||||
{operation.title}
|
||||
</span>
|
||||
</Button>,
|
||||
)}
|
||||
<Button
|
||||
id="more"
|
||||
size={'small'}
|
||||
variant={'secondary'}
|
||||
className="gap-[1px]"
|
||||
>
|
||||
<RiMoreLine className="h-3.5 w-3.5 text-components-button-secondary-text" />
|
||||
<span className="system-xs-medium text-components-button-secondary-text">
|
||||
{t('common.operation.more')}
|
||||
</span>
|
||||
</Button>
|
||||
</div>}
|
||||
<div className="flex items-center self-stretch overflow-hidden" style={{ gap }}>
|
||||
{visibleOpreations.map(operation =>
|
||||
<Button
|
||||
key={operation.id}
|
||||
data-targetid={operation.id}
|
||||
size={'small'}
|
||||
variant={'secondary'}
|
||||
className="gap-[1px]"
|
||||
onClick={operation.onClick}>
|
||||
{cloneElement(operation.icon, { className: 'h-3.5 w-3.5 text-components-button-secondary-text' })}
|
||||
<span className="system-xs-medium text-components-button-secondary-text">
|
||||
{operation.title}
|
||||
</span>
|
||||
</Button>,
|
||||
)}
|
||||
{visibleOpreations.length < operations.length && <PortalToFollowElem
|
||||
open={showMore}
|
||||
onOpenChange={setShowMore}
|
||||
placement='bottom-end'
|
||||
offset={{
|
||||
mainAxis: 4,
|
||||
}}>
|
||||
<PortalToFollowElemTrigger onClick={handleTriggerMore}>
|
||||
<Button
|
||||
size={'small'}
|
||||
variant={'secondary'}
|
||||
className='gap-[1px]'
|
||||
>
|
||||
<RiMoreLine className='h-3.5 w-3.5 text-components-button-secondary-text' />
|
||||
<span className='system-xs-medium text-components-button-secondary-text'>{t('common.operation.more')}</span>
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[21]'>
|
||||
<div className='flex min-w-[264px] flex-col rounded-[12px] border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg backdrop-blur-[5px]'>
|
||||
{moreOperations.map(item => <div
|
||||
key={item.id}
|
||||
className='flex h-8 cursor-pointer items-center gap-x-1 rounded-lg p-1.5 hover:bg-state-base-hover'
|
||||
onClick={item.onClick}
|
||||
>
|
||||
{cloneElement(item.icon, { className: 'h-4 w-4 text-text-tertiary' })}
|
||||
<span className='system-md-regular text-text-secondary'>{item.title}</span>
|
||||
</div>)}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default AppOperations
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue