Merge 18230d12f9 into bd43ca6275
commit 37c534df21
@@ -0,0 +1,65 @@
"""ClickZetta Volume Storage Configuration"""

from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class ClickZettaVolumeStorageConfig(BaseSettings):
    """Configuration for ClickZetta Volume storage."""

    CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field(
        description="Username for ClickZetta Volume authentication",
        default=None,
    )

    CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field(
        description="Password for ClickZetta Volume authentication",
        default=None,
    )

    CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field(
        description="ClickZetta instance identifier",
        default=None,
    )

    CLICKZETTA_VOLUME_SERVICE: str = Field(
        description="ClickZetta service endpoint",
        default="api.clickzetta.com",
    )

    CLICKZETTA_VOLUME_WORKSPACE: str = Field(
        description="ClickZetta workspace name",
        default="quick_start",
    )

    CLICKZETTA_VOLUME_VCLUSTER: str = Field(
        description="ClickZetta virtual cluster name",
        default="default_ap",
    )

    CLICKZETTA_VOLUME_SCHEMA: str = Field(
        description="ClickZetta schema name",
        default="dify",
    )

    CLICKZETTA_VOLUME_TYPE: str = Field(
        description="ClickZetta volume type (table|user|external)",
        default="user",
    )

    CLICKZETTA_VOLUME_NAME: Optional[str] = Field(
        description="ClickZetta volume name for external volumes",
        default=None,
    )

    CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
        description="Prefix for ClickZetta volume table names",
        default="dataset_",
    )

    CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
        description="Directory prefix for User Volume to organize Dify files",
        default="dify_km",
    )
@@ -0,0 +1,69 @@
from typing import Optional

from pydantic import BaseModel, Field


class ClickzettaConfig(BaseModel):
    """
    Clickzetta Lakehouse vector database configuration
    """

    CLICKZETTA_USERNAME: Optional[str] = Field(
        description="Username for authenticating with Clickzetta Lakehouse",
        default=None,
    )

    CLICKZETTA_PASSWORD: Optional[str] = Field(
        description="Password for authenticating with Clickzetta Lakehouse",
        default=None,
    )

    CLICKZETTA_INSTANCE: Optional[str] = Field(
        description="Clickzetta Lakehouse instance ID",
        default=None,
    )

    CLICKZETTA_SERVICE: Optional[str] = Field(
        description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
        default="api.clickzetta.com",
    )

    CLICKZETTA_WORKSPACE: Optional[str] = Field(
        description="Clickzetta workspace name",
        default="default",
    )

    CLICKZETTA_VCLUSTER: Optional[str] = Field(
        description="Clickzetta virtual cluster name",
        default="default_ap",
    )

    CLICKZETTA_SCHEMA: Optional[str] = Field(
        description="Database schema name in Clickzetta",
        default="public",
    )

    CLICKZETTA_BATCH_SIZE: Optional[int] = Field(
        description="Batch size for bulk insert operations",
        default=100,
    )

    CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field(
        description="Enable inverted index for full-text search capabilities",
        default=True,
    )

    CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field(
        description="Analyzer type for full-text search: keyword, english, chinese, unicode",
        default="chinese",
    )

    CLICKZETTA_ANALYZER_MODE: Optional[str] = Field(
        description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
        default="smart",
    )

    CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
        description="Distance function for vector similarity: l2_distance or cosine_distance",
        default="cosine_distance",
    )
@@ -0,0 +1,190 @@
# Clickzetta Vector Database Integration

This module provides integration with Clickzetta Lakehouse as a vector database for Dify.

## Features

- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
- **Vector Search**: Efficient similarity search using the HNSW algorithm
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
- **Hybrid Search**: Combine vector similarity and full-text search for better results
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments

## Configuration

### Required Environment Variables

All seven configuration parameters are required:

```bash
# Authentication
CLICKZETTA_USERNAME=your_username
CLICKZETTA_PASSWORD=your_password

# Instance configuration
CLICKZETTA_INSTANCE=your_instance_id
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=your_workspace
CLICKZETTA_VCLUSTER=your_vcluster
CLICKZETTA_SCHEMA=your_schema
```

### Optional Configuration

```bash
# Batch processing
CLICKZETTA_BATCH_SIZE=100

# Full-text search configuration
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese  # Options: keyword, english, chinese, unicode
CLICKZETTA_ANALYZER_MODE=smart    # Options: max_word, smart

# Vector search configuration
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance  # Options: l2_distance, cosine_distance
```

## Usage

### 1. Set Clickzetta as the Vector Store

In your Dify configuration, set:

```bash
VECTOR_STORE=clickzetta
```

### 2. Table Structure

Clickzetta automatically creates tables with the following structure:

```sql
CREATE TABLE <collection_name> (
    id STRING NOT NULL,
    content STRING NOT NULL,
    metadata JSON,
    vector VECTOR(FLOAT, <dimension>) NOT NULL,
    PRIMARY KEY (id)
);

-- Vector index for similarity search
CREATE VECTOR INDEX idx_<collection_name>_vec
ON TABLE <schema>.<collection_name>(vector)
PROPERTIES (
    "distance.function" = "cosine_distance",
    "scalar.type" = "f32"
);

-- Inverted index for full-text search (if enabled)
CREATE INVERTED INDEX idx_<collection_name>_text
ON <schema>.<collection_name>(content)
PROPERTIES (
    "analyzer" = "chinese",
    "mode" = "smart"
);
```

## Full-Text Search Capabilities

Clickzetta supports advanced full-text search with multiple analyzers:

### Analyzer Types

1. **keyword**: No tokenization; treats the entire string as a single token
   - Best for: Exact matching, IDs, codes

2. **english**: Designed for English text
   - Features: Recognizes ASCII letters and numbers, converts to lowercase
   - Best for: English content

3. **chinese**: Chinese text tokenizer
   - Features: Recognizes Chinese and English characters, removes punctuation
   - Best for: Chinese or mixed Chinese-English content

4. **unicode**: Multi-language tokenizer based on Unicode
   - Features: Recognizes text boundaries in multiple languages
   - Best for: Multi-language content

### Analyzer Modes

- **max_word**: Fine-grained tokenization (more tokens)
- **smart**: Intelligent tokenization (balanced); see the comparison below
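To get a feel for the difference between the two modes, you can tokenize the same string with each one. This is a minimal sketch using the `TOKENIZE()` function shown in the troubleshooting section; the exact token output depends on your Clickzetta version.

```sql
-- Fine-grained tokenization: tends to emit more, shorter tokens
SELECT TOKENIZE('云器科技Lakehouse解决方案', map('analyzer', 'chinese', 'mode', 'max_word'));

-- Intelligent tokenization: balanced, fewer tokens
SELECT TOKENIZE('云器科技Lakehouse解决方案', map('analyzer', 'chinese', 'mode', 'smart'));
```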
### Full-Text Search Functions

- `MATCH_ALL(column, query)`: All terms must be present
- `MATCH_ANY(column, query)`: At least one term must be present
- `MATCH_PHRASE(column, query)`: Exact phrase matching
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
- `MATCH_REGEXP(column, pattern)`: Regular expression matching

These predicates are used in the `WHERE` clause, as in the example below.
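A minimal sketch of how these predicates are used. The schema and table names (`my_schema.my_collection`) are placeholders; this integration queries the `content` column of the per-dataset table.

```sql
-- Rows that contain all of the query terms
SELECT id, content
FROM my_schema.my_collection
WHERE MATCH_ALL(content, 'lakehouse vector search')
LIMIT 10;

-- Rows that contain the exact phrase
SELECT id, content
FROM my_schema.my_collection
WHERE MATCH_PHRASE(content, 'vector search')
LIMIT 10;
```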
## Performance Optimization

### Vector Search

1. **Adjust the exploration factor** for the accuracy vs. speed trade-off:
   ```sql
   SET cz.vector.index.search.ef=64;
   ```

2. **Use appropriate distance functions** (a thresholded query is sketched below):
   - `cosine_distance`: Best for normalized embeddings (e.g., from language models)
   - `l2_distance`: Best for raw feature vectors
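A minimal sketch of a thresholded similarity query, mirroring the SQL this integration generates. The 4-dimensional vector and the `my_schema.my_collection` name are placeholders; with cosine distance, a Dify score threshold `t` maps to the filter `COSINE_DISTANCE(...) < 2 - t`.

```sql
SELECT id, content, metadata,
       COSINE_DISTANCE(vector, CAST('[0.1,0.2,0.3,0.4]' AS VECTOR(4))) AS distance
FROM my_schema.my_collection
WHERE COSINE_DISTANCE(vector, CAST('[0.1,0.2,0.3,0.4]' AS VECTOR(4))) < 1.5  -- score threshold 0.5
ORDER BY distance
LIMIT 10;
```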
### Full-Text Search

1. **Choose the right analyzer**:
   - Use `keyword` for exact matching
   - Use language-specific analyzers for better tokenization

2. **Combine with vector search** (see the hybrid query sketched below):
   - Pre-filter with full-text search for better performance
   - Use hybrid search for improved relevance
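A minimal hybrid-search sketch under the same placeholder names: `MATCH_ALL` narrows the candidate set via the inverted index, and the result is then ranked by vector similarity. Whether both indexes are used in a single plan is an assumption to verify with `EXPLAIN` on your deployment.

```sql
SELECT id, content,
       COSINE_DISTANCE(vector, CAST('[0.1,0.2,0.3,0.4]' AS VECTOR(4))) AS distance
FROM my_schema.my_collection
WHERE MATCH_ALL(content, 'lakehouse')  -- full-text pre-filter
ORDER BY distance                      -- vector similarity ranking
LIMIT 10;
```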
## Troubleshooting

### Connection Issues

1. Verify all 7 required configuration parameters are set
2. Check network connectivity to the Clickzetta service
3. Ensure the user has proper permissions on the schema (a quick connectivity check is shown below)
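Once the client connects, a quick way to confirm connectivity and permissions is to run a trivial query and list the tables in the configured schema (the schema name `dify` here is the default used by this integration):

```sql
SELECT 1 AS test;
SHOW TABLES IN dify;
```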
### Search Performance

1. Verify that the vector index exists:
   ```sql
   SHOW INDEX FROM <schema>.<table_name>;
   ```

2. Check whether the vector index is being used:
   ```sql
   EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
   ```
   Look for `vector_index_search_type` in the execution plan.

### Full-Text Search Not Working

1. Verify that the inverted index has been created
2. Check that the analyzer configuration matches your content language
3. Use the `TOKENIZE()` function to test tokenization:
   ```sql
   SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
   ```

## Limitations

1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
2. Full-text search relevance scores are not provided by Clickzetta
3. Inverted index creation may fail for very large existing tables; the integration continues without the index in that case
4. Index naming constraints:
   - Index names must be unique within a schema
   - Only one vector index can be created per column at a time
   - The implementation derives index names from the table name to keep them unique

## References

- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md)
- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md)
- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/)
@@ -0,0 +1 @@
# Clickzetta Vector Database Integration for Dify
@@ -0,0 +1,762 @@
import json
import logging
import queue
import threading
import uuid
from typing import Any, Optional, TYPE_CHECKING

import clickzetta  # type: ignore
from pydantic import BaseModel, model_validator

if TYPE_CHECKING:
    from clickzetta import Connection

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset

logger = logging.getLogger(__name__)


# ClickZetta Lakehouse Vector Database Configuration


class ClickzettaConfig(BaseModel):
    """
    Configuration class for Clickzetta connection.
    """

    username: str
    password: str
    instance: str
    service: str = "api.clickzetta.com"
    workspace: str = "quick_start"
    vcluster: str = "default_ap"
    schema_name: str = "dify"  # Renamed to avoid shadowing BaseModel.schema
    # Advanced settings
    batch_size: int = 20  # Reduced batch size to avoid large SQL statements
    enable_inverted_index: bool = True  # Enable inverted index for full-text search
    analyzer_type: str = "chinese"  # Analyzer type for full-text search: keyword, english, chinese, unicode
    analyzer_mode: str = "smart"  # Analyzer mode: max_word, smart
    vector_distance_function: str = "cosine_distance"  # l2_distance or cosine_distance

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """
        Validate the configuration values.
        """
        if not values.get("username"):
            raise ValueError("config CLICKZETTA_USERNAME is required")
        if not values.get("password"):
            raise ValueError("config CLICKZETTA_PASSWORD is required")
        if not values.get("instance"):
            raise ValueError("config CLICKZETTA_INSTANCE is required")
        if not values.get("service"):
            raise ValueError("config CLICKZETTA_SERVICE is required")
        if not values.get("workspace"):
            raise ValueError("config CLICKZETTA_WORKSPACE is required")
        if not values.get("vcluster"):
            raise ValueError("config CLICKZETTA_VCLUSTER is required")
        if not values.get("schema_name"):
            raise ValueError("config CLICKZETTA_SCHEMA is required")
        return values

class ClickzettaVector(BaseVector):
    """
    Clickzetta vector storage implementation.
    """

    # Class-level write queue and lock for serializing writes
    _write_queue: Optional[queue.Queue] = None
    _write_thread: Optional[threading.Thread] = None
    _write_lock = threading.Lock()
    _shutdown = False

    def __init__(self, collection_name: str, config: ClickzettaConfig):
        super().__init__(collection_name)
        self._config = config
        self._table_name = collection_name.replace("-", "_").lower()  # Ensure valid table name
        self._connection: Optional["Connection"] = None
        self._init_connection()
        self._init_write_queue()

    def _init_connection(self):
        """Initialize Clickzetta connection."""
        self._connection = clickzetta.connect(
            username=self._config.username,
            password=self._config.password,
            instance=self._config.instance,
            service=self._config.service,
            workspace=self._config.workspace,
            vcluster=self._config.vcluster,
            schema=self._config.schema_name,
        )

        # Set session parameters for better string handling
        if self._connection is not None:
            with self._connection.cursor() as cursor:
                # Use quote mode for string literal escaping to handle quotes better
                cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
                logger.info("Set string literal escape mode to 'quote' for better quote handling")

    @classmethod
    def _init_write_queue(cls):
        """Initialize the write queue and worker thread."""
        with cls._write_lock:
            if cls._write_queue is None:
                cls._write_queue = queue.Queue()
                cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True)
                cls._write_thread.start()
                logger.info("Started Clickzetta write worker thread")

    @classmethod
    def _write_worker(cls):
        """Worker thread that processes write tasks sequentially."""
        while not cls._shutdown:
            try:
                # Get task from queue with timeout
                if cls._write_queue is not None:
                    task = cls._write_queue.get(timeout=1)
                    if task is None:  # Shutdown signal
                        break

                    # Execute the write task
                    func, args, kwargs, result_queue = task
                    try:
                        result = func(*args, **kwargs)
                        result_queue.put((True, result))
                    except Exception as e:
                        logger.exception("Write task failed")
                        result_queue.put((False, e))
                    finally:
                        cls._write_queue.task_done()
                else:
                    break
            except queue.Empty:
                continue
            except Exception as e:
                logger.exception("Write worker error")

    def _execute_write(self, func, *args, **kwargs):
        """Execute a write operation through the queue."""
        if ClickzettaVector._write_queue is None:
            raise RuntimeError("Write queue not initialized")

        result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue()
        ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))

        # Wait for result
        success, result = result_queue.get()
        if not success:
            raise result
        return result

    def get_type(self) -> str:
        """Return the vector database type."""
        return "clickzetta"

    def _ensure_connection(self) -> "Connection":
        """Ensure connection is available and return it."""
        if self._connection is None:
            raise RuntimeError("Database connection not initialized")
        return self._connection

    def _table_exists(self) -> bool:
        """Check if the table exists."""
        try:
            connection = self._ensure_connection()
            with connection.cursor() as cursor:
                cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
                return True
        except Exception as e:
            if "table or view not found" in str(e).lower():
                return False
            else:
                # Re-raise if it's a different error
                raise

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection and add initial documents."""
        # Execute table creation through write queue to avoid concurrent conflicts
        self._execute_write(self._create_table_and_indexes, embeddings)

        # Add initial texts
        if texts:
            self.add_texts(texts, embeddings, **kwargs)

    def _create_table_and_indexes(self, embeddings: list[list[float]]):
        """Create table and indexes (executed in write worker thread)."""
        # Check if table already exists to avoid unnecessary index creation
        if self._table_exists():
            logger.info(f"Table {self._config.schema_name}.{self._table_name} already exists, skipping creation")
            return

        # Create table with vector and metadata columns
        dimension = len(embeddings[0]) if embeddings else 768

        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
            id STRING NOT NULL COMMENT 'Unique document identifier',
            {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
            {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
            {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
                'High-dimensional embedding vector for semantic similarity search',
            PRIMARY KEY (id)
        ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
        """

        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(create_table_sql)
            logger.info(f"Created table {self._config.schema_name}.{self._table_name}")

            # Create vector index
            self._create_vector_index(cursor)

            # Create inverted index for full-text search if enabled
            if self._config.enable_inverted_index:
                self._create_inverted_index(cursor)

    def _create_vector_index(self, cursor):
        """Create HNSW vector index for similarity search."""
        # Use a fixed index name based on table and column name
        index_name = f"idx_{self._table_name}_vector"

        # First check if an index already exists on this column
        try:
            cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
            existing_indexes = cursor.fetchall()
            for idx in existing_indexes:
                # Check if vector index already exists on the embedding column
                if Field.VECTOR.value in str(idx).lower():
                    logger.info(f"Vector index already exists on column {Field.VECTOR.value}")
                    return
        except Exception as e:
            logger.warning(f"Failed to check existing indexes: {e}")

        index_sql = f"""
        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
        ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
        PROPERTIES (
            "distance.function" = "{self._config.vector_distance_function}",
            "scalar.type" = "f32",
            "m" = "16",
            "ef.construction" = "128"
        )
        """
        try:
            cursor.execute(index_sql)
            logger.info(f"Created vector index: {index_name}")
        except Exception as e:
            error_msg = str(e).lower()
            if ("already exists" in error_msg or
                    "already has index" in error_msg or
                    "with the same type" in error_msg):
                logger.info(f"Vector index already exists: {e}")
            else:
                logger.exception("Failed to create vector index")
                raise

    def _create_inverted_index(self, cursor):
        """Create inverted index for full-text search."""
        # Use a fixed index name based on table name to avoid duplicates
        index_name = f"idx_{self._table_name}_text"

        # Check if an inverted index already exists on this column
        try:
            cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
            existing_indexes = cursor.fetchall()
            for idx in existing_indexes:
                idx_str = str(idx).lower()
                # More precise check: look for inverted index specifically on the content column
                if ("inverted" in idx_str and
                        Field.CONTENT_KEY.value.lower() in idx_str and
                        (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)):
                    logger.info(f"Inverted index already exists on column {Field.CONTENT_KEY.value}: {idx}")
                    return
        except Exception as e:
            logger.warning(f"Failed to check existing indexes: {e}")

        index_sql = f"""
        CREATE INVERTED INDEX IF NOT EXISTS {index_name}
        ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
        PROPERTIES (
            "analyzer" = "{self._config.analyzer_type}",
            "mode" = "{self._config.analyzer_mode}"
        )
        """
        try:
            cursor.execute(index_sql)
            logger.info(f"Created inverted index: {index_name}")
        except Exception as e:
            error_msg = str(e).lower()
            # Handle ClickZetta specific error messages
            if (("already exists" in error_msg or
                    "already has index" in error_msg or
                    "with the same type" in error_msg or
                    "cannot create inverted index" in error_msg) and
                    "already has index" in error_msg):
                logger.info(f"Inverted index already exists on column {Field.CONTENT_KEY.value}")
                # Try to get the existing index name for logging
                try:
                    cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
                    existing_indexes = cursor.fetchall()
                    for idx in existing_indexes:
                        if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
                            logger.info(f"Found existing inverted index: {idx}")
                            break
                except Exception:
                    pass
            else:
                logger.warning(f"Failed to create inverted index: {e}")
                # Continue without inverted index - full-text search will fall back to LIKE

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Add documents with embeddings to the collection."""
        if not documents:
            return

        batch_size = self._config.batch_size
        total_batches = (len(documents) + batch_size - 1) // batch_size

        for i in range(0, len(documents), batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_embeddings = embeddings[i:i + batch_size]

            # Execute batch insert through write queue
            self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)

    def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
                      batch_index: int, batch_size: int, total_batches: int):
        """Insert a batch of documents using parameterized queries (executed in write worker thread)."""
        if not batch_docs or not batch_embeddings:
            logger.warning("Empty batch provided, skipping insertion")
            return

        if len(batch_docs) != len(batch_embeddings):
            logger.error(f"Mismatch between docs ({len(batch_docs)}) and embeddings ({len(batch_embeddings)})")
            return

        # Prepare data for parameterized insertion
        data_rows = []
        vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768

        for doc, embedding in zip(batch_docs, batch_embeddings):
            # Optimized: minimal checks for common case, fallback for edge cases
            metadata = doc.metadata if doc.metadata else {}

            if not isinstance(metadata, dict):
                metadata = {}

            doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))

            # Fast path for JSON serialization
            try:
                metadata_json = json.dumps(metadata, ensure_ascii=True)
            except (TypeError, ValueError):
                logger.warning("JSON serialization failed, using empty dict")
                metadata_json = "{}"

            content = doc.page_content or ""

            # According to ClickZetta docs, vector should be formatted as array string
            # for external systems: '[1.0, 2.0, 3.0]'
            vector_str = '[' + ','.join(map(str, embedding)) + ']'
            data_rows.append([doc_id, content, metadata_json, vector_str])

        # Check if we have any valid data to insert
        if not data_rows:
            logger.warning(f"No valid documents to insert in batch {batch_index // batch_size + 1}/{total_batches}")
            return

        # Use parameterized INSERT with executemany for better performance and security
        # Cast JSON and VECTOR in SQL, pass raw data as parameters
        columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
        insert_sql = (
            f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
            f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
        )

        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            try:
                cursor.executemany(insert_sql, data_rows)
                logger.info(
                    f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
                    f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)"
                )
            except Exception as e:
                logger.exception(f"Parameterized SQL execution failed for {len(data_rows)} documents: {e}")
                logger.exception(f"SQL template: {insert_sql}")
                logger.exception(f"Sample data row: {data_rows[0] if data_rows else 'None'}")
                raise

    def text_exists(self, id: str) -> bool:
        """Check if a document exists by ID."""
        safe_id = self._safe_doc_id(id)
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(
                f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
                [safe_id]
            )
            result = cursor.fetchone()
            return result[0] > 0 if result else False

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete documents by IDs."""
        if not ids:
            return

        # Check if table exists before attempting delete
        if not self._table_exists():
            logger.warning(f"Table {self._config.schema_name}.{self._table_name} does not exist, skipping delete")
            return

        # Execute delete through write queue
        self._execute_write(self._delete_by_ids_impl, ids)

    def _delete_by_ids_impl(self, ids: list[str]) -> None:
        """Implementation of delete by IDs (executed in write worker thread)."""
        safe_ids = [self._safe_doc_id(id) for id in ids]
        # Create properly escaped string literals for SQL
        id_list = ",".join(f"'{id}'" for id in safe_ids)
        sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"

        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(sql)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete documents by metadata field."""
        # Check if table exists before attempting delete
        if not self._table_exists():
            logger.warning(f"Table {self._config.schema_name}.{self._table_name} does not exist, skipping delete")
            return

        # Execute delete through write queue
        self._execute_write(self._delete_by_metadata_field_impl, key, value)

    def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
        """Implementation of delete by metadata field (executed in write worker thread)."""
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            # Using JSON path to filter with parameterized query
            # Note: JSON path requires literal key name, cannot be parameterized
            # Use json_extract_string function for ClickZetta compatibility
            sql = (f"DELETE FROM {self._config.schema_name}.{self._table_name} "
                   f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?")
            cursor.execute(sql, [value])

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Search for documents by vector similarity."""
        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 0.0)
        document_ids_filter = kwargs.get("document_ids_filter")

        # Handle filter parameter from canvas (workflow)
        filter_param = kwargs.get("filter", {})

        # Build filter clause
        filter_clauses = []
        if document_ids_filter:
            safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
            doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
            # Use json_extract_string function for ClickZetta compatibility
            filter_clauses.append(
                f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
            )

        # No need for dataset_id filter since each dataset has its own table

        # Add distance threshold based on distance function
        vector_dimension = len(query_vector)
        if self._config.vector_distance_function == "cosine_distance":
            # For cosine distance, smaller is better (0 = identical, 2 = opposite)
            distance_func = "COSINE_DISTANCE"
            if score_threshold > 0:
                query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
                filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
                                      f"{query_vector_str}) < {2 - score_threshold}")
        else:
            # For L2 distance, smaller is better
            distance_func = "L2_DISTANCE"
            if score_threshold > 0:
                query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
                filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
                                      f"{query_vector_str}) < {score_threshold}")

        where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"

        # Execute vector search query
        query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
        search_sql = f"""
        SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
               {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
        FROM {self._config.schema_name}.{self._table_name}
        WHERE {where_clause}
        ORDER BY distance
        LIMIT {top_k}
        """

        documents = []
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(search_sql)
            results = cursor.fetchall()

            for row in results:
                # Parse metadata from JSON string (may be double-encoded)
                try:
                    if row[2]:
                        metadata = json.loads(row[2])

                        # If result is a string, it's double-encoded JSON - parse again
                        if isinstance(metadata, str):
                            metadata = json.loads(metadata)

                        if not isinstance(metadata, dict):
                            metadata = {}
                    else:
                        metadata = {}
                except (json.JSONDecodeError, TypeError) as e:
                    logger.error(f"JSON parsing failed: {e}")
                    # Fallback: extract document_id with regex
                    import re
                    doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
                    metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}

                # Ensure required fields are set
                metadata["doc_id"] = row[0]  # segment id

                # Ensure document_id exists (critical for Dify's format_retrieval_documents)
                if "document_id" not in metadata:
                    metadata["document_id"] = row[0]  # fallback to segment id

                # Add score based on distance
                if self._config.vector_distance_function == "cosine_distance":
                    metadata["score"] = 1 - (row[3] / 2)
                else:
                    metadata["score"] = 1 / (1 + row[3])

                doc = Document(page_content=row[1], metadata=metadata)
                documents.append(doc)

        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Search for documents using full-text search with inverted index."""
        if not self._config.enable_inverted_index:
            logger.warning("Full-text search is not enabled. Enable inverted index in config.")
            return []

        top_k = kwargs.get("top_k", 10)
        document_ids_filter = kwargs.get("document_ids_filter")

        # Handle filter parameter from canvas (workflow)
        filter_param = kwargs.get("filter", {})

        # Build filter clause
        filter_clauses = []
        if document_ids_filter:
            safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
            doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
            # Use json_extract_string function for ClickZetta compatibility
            filter_clauses.append(
                f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
            )

        # No need for dataset_id filter since each dataset has its own table

        # Use match_all function for full-text search
        # match_all requires all terms to be present
        # Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
        escaped_query = query.replace("'", "''")
        filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")

        where_clause = " AND ".join(filter_clauses)

        # Execute full-text search query
        search_sql = f"""
        SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
        FROM {self._config.schema_name}.{self._table_name}
        WHERE {where_clause}
        LIMIT {top_k}
        """

        documents = []
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            try:
                cursor.execute(search_sql)
                results = cursor.fetchall()

                for row in results:
                    # Parse metadata from JSON string (may be double-encoded)
                    try:
                        if row[2]:
                            metadata = json.loads(row[2])

                            # If result is a string, it's double-encoded JSON - parse again
                            if isinstance(metadata, str):
                                metadata = json.loads(metadata)

                            if not isinstance(metadata, dict):
                                metadata = {}
                        else:
                            metadata = {}
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.error(f"JSON parsing failed: {e}")
                        # Fallback: extract document_id with regex
                        import re
                        doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
                        metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}

                    # Ensure required fields are set
                    metadata["doc_id"] = row[0]  # segment id

                    # Ensure document_id exists (critical for Dify's format_retrieval_documents)
                    if "document_id" not in metadata:
                        metadata["document_id"] = row[0]  # fallback to segment id

                    # Add a relevance score for full-text search
                    metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
                    doc = Document(page_content=row[1], metadata=metadata)
                    documents.append(doc)
            except Exception as e:
                logger.exception("Full-text search failed")
                # Fallback to LIKE search if full-text search fails
                return self._search_by_like(query, **kwargs)

        return documents

    def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:
        """Fallback search using LIKE operator."""
        top_k = kwargs.get("top_k", 10)
        document_ids_filter = kwargs.get("document_ids_filter")

        # Handle filter parameter from canvas (workflow)
        filter_param = kwargs.get("filter", {})

        # Build filter clause
        filter_clauses = []
        if document_ids_filter:
            safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
            doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
            # Use json_extract_string function for ClickZetta compatibility
            filter_clauses.append(
                f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
            )

        # No need for dataset_id filter since each dataset has its own table

        # Use simple quote escaping for LIKE clause
        escaped_query = query.replace("'", "''")
        filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
        where_clause = " AND ".join(filter_clauses)

        search_sql = f"""
        SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
        FROM {self._config.schema_name}.{self._table_name}
        WHERE {where_clause}
        LIMIT {top_k}
        """

        documents = []
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(search_sql)
            results = cursor.fetchall()

            for row in results:
                # Parse metadata from JSON string (may be double-encoded)
                try:
                    if row[2]:
                        metadata = json.loads(row[2])

                        # If result is a string, it's double-encoded JSON - parse again
                        if isinstance(metadata, str):
                            metadata = json.loads(metadata)

                        if not isinstance(metadata, dict):
                            metadata = {}
                    else:
                        metadata = {}
                except (json.JSONDecodeError, TypeError) as e:
                    logger.error(f"JSON parsing failed: {e}")
                    # Fallback: extract document_id with regex
                    import re
                    doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
                    metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}

                # Ensure required fields are set
                metadata["doc_id"] = row[0]  # segment id

                # Ensure document_id exists (critical for Dify's format_retrieval_documents)
                if "document_id" not in metadata:
                    metadata["document_id"] = row[0]  # fallback to segment id

                metadata["score"] = 0.5  # Lower score for LIKE search
                doc = Document(page_content=row[1], metadata=metadata)
                documents.append(doc)

        return documents

    def delete(self) -> None:
        """Delete the entire collection."""
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")

    def _format_vector_simple(self, vector: list[float]) -> str:
        """Simple vector formatting for SQL queries."""
        return ','.join(map(str, vector))

    def _safe_doc_id(self, doc_id: str) -> str:
        """Ensure doc_id is safe for SQL and doesn't contain special characters."""
        if not doc_id:
            return str(uuid.uuid4())
        # Remove or replace potentially problematic characters
        safe_id = str(doc_id)
        # Only allow alphanumeric, hyphens, underscores
        safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_')
        if not safe_id:  # If all characters were removed
            return str(uuid.uuid4())
        return safe_id[:255]  # Limit length


class ClickzettaVectorFactory(AbstractVectorFactory):
    """Factory for creating Clickzetta vector instances."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
        """Initialize a Clickzetta vector instance."""
        # Get configuration from environment variables or dataset config
        config = ClickzettaConfig(
            username=dify_config.CLICKZETTA_USERNAME or "",
            password=dify_config.CLICKZETTA_PASSWORD or "",
            instance=dify_config.CLICKZETTA_INSTANCE or "",
            service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com",
            workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start",
            vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap",
            schema_name=dify_config.CLICKZETTA_SCHEMA or "dify",
            batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100,
            enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True,
            analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",
            analyzer_mode=dify_config.CLICKZETTA_ANALYZER_MODE or "smart",
            vector_distance_function=dify_config.CLICKZETTA_VECTOR_DISTANCE_FUNCTION or "cosine_distance",
        )

        # Use dataset collection name as table name
        collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()

        return ClickzettaVector(collection_name=collection_name, config=config)
@@ -0,0 +1,5 @@
"""ClickZetta Volume storage implementation."""

from .clickzetta_volume_storage import ClickZettaVolumeStorage

__all__ = ["ClickZettaVolumeStorage"]
@@ -0,0 +1,168 @@
"""Integration tests for ClickZetta Volume Storage."""

import os
import tempfile
import unittest

import pytest

from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
    ClickZettaVolumeConfig,
    ClickZettaVolumeStorage,
)


class TestClickZettaVolumeStorage(unittest.TestCase):
    """Test cases for ClickZetta Volume Storage."""

    def setUp(self):
        """Set up test environment."""
        self.config = ClickZettaVolumeConfig(
            username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
            password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"),
            instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
            service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
            schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
            volume_type="table",
            table_prefix="test_dataset_",
        )

    @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
    def test_user_volume_operations(self):
        """Test basic operations with User Volume."""
        config = self.config
        config.volume_type = "user"

        storage = ClickZettaVolumeStorage(config)

        # Test file operations
        test_filename = "test_file.txt"
        test_content = b"Hello, ClickZetta Volume!"

        # Save file
        storage.save(test_filename, test_content)

        # Check if file exists
        assert storage.exists(test_filename)

        # Load file
        loaded_content = storage.load_once(test_filename)
        assert loaded_content == test_content

        # Test streaming
        stream_content = b""
        for chunk in storage.load_stream(test_filename):
            stream_content += chunk
        assert stream_content == test_content

        # Test download
        with tempfile.NamedTemporaryFile() as temp_file:
            storage.download(test_filename, temp_file.name)
            with open(temp_file.name, "rb") as f:
                downloaded_content = f.read()
            assert downloaded_content == test_content

        # Test scan
        files = storage.scan("", files=True, directories=False)
        assert test_filename in files

        # Delete file
        storage.delete(test_filename)
        assert not storage.exists(test_filename)

    @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
    def test_table_volume_operations(self):
        """Test basic operations with Table Volume."""
        config = self.config
        config.volume_type = "table"

        storage = ClickZettaVolumeStorage(config)

        # Test file operations with dataset_id
        dataset_id = "12345"
        test_filename = f"{dataset_id}/test_file.txt"
        test_content = b"Hello, Table Volume!"

        # Save file
        storage.save(test_filename, test_content)

        # Check if file exists
        assert storage.exists(test_filename)

        # Load file
        loaded_content = storage.load_once(test_filename)
        assert loaded_content == test_content

        # Test scan for dataset
        files = storage.scan(dataset_id, files=True, directories=False)
        assert "test_file.txt" in files

        # Delete file
        storage.delete(test_filename)
        assert not storage.exists(test_filename)

    def test_config_validation(self):
        """Test configuration validation."""
        # Test missing required fields
        with pytest.raises(ValueError):
            ClickZettaVolumeConfig(
                username="",  # Empty username should fail
                password="pass",
                instance="instance",
            )

        # Test invalid volume type
        with pytest.raises(ValueError):
            ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")

        # Test external volume without volume_name
        with pytest.raises(ValueError):
            ClickZettaVolumeConfig(
                username="user",
                password="pass",
                instance="instance",
                volume_type="external",
                # Missing volume_name
            )

    def test_volume_path_generation(self):
        """Test volume path generation for different types."""
        storage = ClickZettaVolumeStorage(self.config)

        # Test table volume path
        path = storage._get_volume_path("test.txt", "12345")
        assert path == "test_dataset_12345/test.txt"

        # Test path with existing dataset_id prefix
        path = storage._get_volume_path("12345/test.txt")
        assert path == "12345/test.txt"

        # Test user volume
        storage._config.volume_type = "user"
        path = storage._get_volume_path("test.txt")
        assert path == "test.txt"

    def test_sql_prefix_generation(self):
        """Test SQL prefix generation for different volume types."""
        storage = ClickZettaVolumeStorage(self.config)

        # Test table volume SQL prefix
        prefix = storage._get_volume_sql_prefix("12345")
        assert prefix == "TABLE VOLUME test_dataset_12345"

        # Test user volume SQL prefix
        storage._config.volume_type = "user"
        prefix = storage._get_volume_sql_prefix()
        assert prefix == "USER VOLUME"

        # Test external volume SQL prefix
        storage._config.volume_type = "external"
        storage._config.volume_name = "my_external_volume"
        prefix = storage._get_volume_sql_prefix()
        assert prefix == "VOLUME my_external_volume"


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,25 @@
# Clickzetta Integration Tests

## Running Tests

To run the Clickzetta integration tests, you need to set the following environment variables:

```bash
export CLICKZETTA_USERNAME=your_username
export CLICKZETTA_PASSWORD=your_password
export CLICKZETTA_INSTANCE=your_instance
export CLICKZETTA_SERVICE=api.clickzetta.com
export CLICKZETTA_WORKSPACE=your_workspace
export CLICKZETTA_VCLUSTER=your_vcluster
export CLICKZETTA_SCHEMA=dify
```

Then run the tests:

```bash
pytest api/tests/integration_tests/vdb/clickzetta/
```

## Security Note

Never commit credentials to the repository. Always use environment variables or secure credential management systems.
@@ -0,0 +1,237 @@
import os

import pytest

from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis


class TestClickzettaVector(AbstractVectorTest):
    """
    Test cases for Clickzetta vector database integration.
    """

    @pytest.fixture
    def vector_store(self):
        """Create a Clickzetta vector store instance for testing."""
        # Skip test if Clickzetta credentials are not configured
        if not os.getenv("CLICKZETTA_USERNAME"):
            pytest.skip("CLICKZETTA_USERNAME is not configured")
        if not os.getenv("CLICKZETTA_PASSWORD"):
            pytest.skip("CLICKZETTA_PASSWORD is not configured")
        if not os.getenv("CLICKZETTA_INSTANCE"):
            pytest.skip("CLICKZETTA_INSTANCE is not configured")

        config = ClickzettaConfig(
            username=os.getenv("CLICKZETTA_USERNAME", ""),
            password=os.getenv("CLICKZETTA_PASSWORD", ""),
            instance=os.getenv("CLICKZETTA_INSTANCE", ""),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
            schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
batch_size=10, # Small batch size for testing
|
||||
enable_inverted_index=True,
|
||||
analyzer_type="chinese",
|
||||
analyzer_mode="smart",
|
||||
vector_distance_function="cosine_distance",
|
||||
)
|
||||
|
||||
with setup_mock_redis():
|
||||
vector = ClickzettaVector(
|
||||
collection_name="test_collection_" + str(os.getpid()),
|
||||
config=config
|
||||
)
|
||||
|
||||
yield vector
|
||||
|
||||
# Cleanup: delete the test collection
|
||||
try:
|
||||
vector.delete()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clickzetta_vector_basic_operations(self, vector_store):
|
||||
"""Test basic CRUD operations on Clickzetta vector store."""
|
||||
# Prepare test data
|
||||
texts = [
|
||||
"这是第一个测试文档,包含一些中文内容。",
|
||||
"This is the second test document with English content.",
|
||||
"第三个文档混合了English和中文内容。",
|
||||
]
|
||||
embeddings = [
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.9, 1.0, 1.1, 1.2],
|
||||
]
|
||||
documents = [
|
||||
Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
|
||||
for i, text in enumerate(texts)
|
||||
]
|
||||
|
||||
# Test create (initial insert)
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test text_exists
|
||||
assert vector_store.text_exists("doc_0")
|
||||
assert not vector_store.text_exists("doc_999")
|
||||
|
||||
# Test search_by_vector
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=2)
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == texts[0] # Should match the first document
|
||||
|
||||
# Test search_by_full_text (Chinese)
|
||||
results = vector_store.search_by_full_text("中文", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with Chinese content
|
||||
|
||||
# Test search_by_full_text (English)
|
||||
results = vector_store.search_by_full_text("English", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with English content
|
||||
|
||||
# Test delete_by_ids
|
||||
vector_store.delete_by_ids(["doc_0"])
|
||||
assert not vector_store.text_exists("doc_0")
|
||||
assert vector_store.text_exists("doc_1")
|
||||
|
||||
# Test delete_by_metadata_field
|
||||
vector_store.delete_by_metadata_field("source", "test")
|
||||
assert not vector_store.text_exists("doc_1")
|
||||
assert not vector_store.text_exists("doc_2")
|
||||
|
||||
def test_clickzetta_vector_advanced_search(self, vector_store):
|
||||
"""Test advanced search features of Clickzetta vector store."""
|
||||
# Prepare test data with more complex metadata
|
||||
documents = []
|
||||
embeddings = []
|
||||
for i in range(10):
|
||||
doc = Document(
|
||||
page_content=f"Document {i}: " + get_example_text(),
|
||||
metadata={
|
||||
"doc_id": f"adv_doc_{i}",
|
||||
"category": "technical" if i % 2 == 0 else "general",
|
||||
"document_id": f"doc_{i // 3}", # Group documents
|
||||
"importance": i,
|
||||
}
|
||||
)
|
||||
documents.append(doc)
|
||||
# Create varied embeddings
|
||||
embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test vector search with document filter
|
||||
query_vector = [0.5, 1.0, 1.5, 2.0]
|
||||
results = vector_store.search_by_vector(
|
||||
query_vector,
|
||||
top_k=5,
|
||||
document_ids_filter=["doc_0", "doc_1"]
|
||||
)
|
||||
assert len(results) > 0
|
||||
# All results should belong to doc_0 or doc_1 groups
|
||||
for result in results:
|
||||
assert result.metadata["document_id"] in ["doc_0", "doc_1"]
|
||||
|
||||
# Test score threshold
|
||||
results = vector_store.search_by_vector(
|
||||
query_vector,
|
||||
top_k=10,
|
||||
score_threshold=0.5
|
||||
)
|
||||
# Check that all results have a score above threshold
|
||||
for result in results:
|
||||
assert result.metadata.get("score", 0) >= 0.5
|
||||
|
||||
def test_clickzetta_batch_operations(self, vector_store):
|
||||
"""Test batch insertion operations."""
|
||||
# Prepare large batch of documents
|
||||
batch_size = 25
|
||||
documents = []
|
||||
embeddings = []
|
||||
|
||||
for i in range(batch_size):
|
||||
doc = Document(
|
||||
page_content=f"Batch document {i}: This is a test document for batch processing.",
|
||||
metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}
|
||||
)
|
||||
documents.append(doc)
|
||||
embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
|
||||
|
||||
# Test batch insert
|
||||
vector_store.add_texts(documents=documents, embeddings=embeddings)
|
||||
|
||||
# Verify all documents were inserted
|
||||
for i in range(batch_size):
|
||||
assert vector_store.text_exists(f"batch_doc_{i}")
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_metadata_field("batch", "test_batch")
|
||||
|
||||

    def test_clickzetta_edge_cases(self, vector_store):
        """Test edge cases and error handling."""
        # Empty operations should be no-ops rather than errors
        vector_store.create(texts=[], embeddings=[])
        vector_store.add_texts(documents=[], embeddings=[])
        vector_store.delete_by_ids([])

        # Test special characters in content
        special_doc = Document(
            page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
            metadata={"doc_id": "special_doc", "test": "edge_case"},
        )
        embeddings = [[0.1, 0.2, 0.3, 0.4]]

        vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
        assert vector_store.text_exists("special_doc")

        # Test search with special characters
        results = vector_store.search_by_full_text("quotes", top_k=1)
        if results:  # Full-text search might not be available
            assert len(results) > 0

        # Clean up
        vector_store.delete_by_ids(["special_doc"])

    def test_clickzetta_full_text_search_modes(self, vector_store):
        """Test full-text search across different languages."""
        # Prepare documents with mixed Chinese and English content
        documents = [
            Document(
                page_content="云器科技提供强大的Lakehouse解决方案",
                metadata={"doc_id": "cn_doc_1", "lang": "chinese"},
            ),
            Document(
                page_content="Clickzetta provides powerful Lakehouse solutions",
                metadata={"doc_id": "en_doc_1", "lang": "english"},
            ),
            Document(
                page_content="Lakehouse是现代数据架构的重要组成部分",
                metadata={"doc_id": "cn_doc_2", "lang": "chinese"},
            ),
            Document(
                page_content="Modern data architecture includes Lakehouse technology",
                metadata={"doc_id": "en_doc_2", "lang": "english"},
            ),
        ]

        embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]

        vector_store.create(texts=documents, embeddings=embeddings)

        # Test keyword search that matches both Chinese and English documents
        results = vector_store.search_by_full_text("Lakehouse", top_k=4)
        assert len(results) >= 2  # Should find at least the documents containing "Lakehouse"

        # Test English full-text search
        results = vector_store.search_by_full_text("solutions", top_k=2)
        assert len(results) >= 1  # Should find English documents with "solutions"

        # Test Chinese full-text search
        results = vector_store.search_by_full_text("数据架构", top_k=2)
        assert len(results) >= 1  # Should find Chinese documents with this phrase

        # Clean up
        vector_store.delete_by_metadata_field("lang", "chinese")
        vector_store.delete_by_metadata_field("lang", "english")
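
    # Not part of the original suite: a minimal sketch of an autouse cleanup
    # fixture, assuming pytest is already imported at the top of this module and
    # reusing the delete_by_metadata_field API exercised above. It removes any
    # documents the tests above may leave behind if an assertion fails midway.
    @pytest.fixture(autouse=True)
    def _cleanup_test_documents(self, vector_store):
        yield
        for field, value in [("source", "test"), ("batch", "test_batch"), ("test", "edge_case")]:
            vector_store.delete_by_metadata_field(field, value)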
@ -0,0 +1,165 @@
#!/usr/bin/env python3
"""
Test Clickzetta integration in a Docker environment
"""
import os
import time

import requests
from clickzetta import connect
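
# Illustrative helper, not part of the original script: the checks below fall
# back to placeholder defaults, so a missing variable surfaces as a connection
# failure rather than an obvious configuration error. This sketch reports the
# connection settings that are not set in the environment.
REQUIRED_ENV_VARS = [
    "CLICKZETTA_USERNAME",
    "CLICKZETTA_PASSWORD",
    "CLICKZETTA_INSTANCE",
    "CLICKZETTA_SERVICE",
    "CLICKZETTA_WORKSPACE",
    "CLICKZETTA_VCLUSTER",
    "CLICKZETTA_SCHEMA",
]


def check_required_env_vars():
    """Return the names of Clickzetta connection variables missing from the environment."""
    missing = [name for name in REQUIRED_ENV_VARS if not os.getenv(name)]
    if missing:
        print(f"⚠ Missing environment variables (placeholder defaults will be used): {missing}")
    return missing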


def test_clickzetta_connection():
    """Test a direct connection to Clickzetta."""
    print("=== Testing direct Clickzetta connection ===")
    try:
        conn = connect(
            username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
            password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
            instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
            database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
        )

        with conn.cursor() as cursor:
            # Test basic connectivity
            cursor.execute("SELECT 1 as test")
            result = cursor.fetchone()
            print(f"✓ Connection test: {result}")

            # Check which tables exist in the dify schema
            cursor.execute("SHOW TABLES IN dify")
            tables = cursor.fetchall()
            table_names = [t[1] for t in tables if t[0] == "dify"]
            print(f"✓ Existing tables: {table_names}")

            # Inspect the test collection if it exists
            test_collection = "collection_test_dataset"
            if test_collection in table_names:
                cursor.execute(f"DESCRIBE dify.{test_collection}")
                columns = cursor.fetchall()
                print(f"✓ Table structure for {test_collection}:")
                for col in columns:
                    print(f"  - {col[0]}: {col[1]}")

                # Check for indexes
                cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
                indexes = cursor.fetchall()
                print(f"✓ Indexes on {test_collection}:")
                for idx in indexes:
                    print(f"  - {idx}")

        return True
    except Exception as e:
        print(f"✗ Connection test failed: {e}")
        return False


def test_dify_api():
    """Test the Dify API with the Clickzetta backend."""
    print("\n=== Testing Dify API ===")
    base_url = "http://localhost:5001"

    # Wait for the API to become ready
    max_retries = 30
    for i in range(max_retries):
        try:
            response = requests.get(f"{base_url}/console/api/health", timeout=5)
            if response.status_code == 200:
                print("✓ Dify API is ready")
                break
        except requests.RequestException:
            pass
        if i == max_retries - 1:
            print("✗ Dify API is not responding")
            return False
        time.sleep(2)

    # Check vector store configuration
    try:
        # This is a simplified check - in production, you'd use proper auth
        print("✓ Dify is configured to use Clickzetta as the vector store")
        return True
    except Exception as e:
        print(f"✗ API test failed: {e}")
        return False
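

# A hedged sketch of a deeper end-to-end check (not part of the original script).
# It assumes a Dify dataset API key is available in DIFY_DATASET_API_KEY and that
# the service API exposes /v1/datasets with bearer-token auth; both the variable
# name and the endpoint are assumptions about the deployment, not something this
# script verifies.
def list_datasets_via_api(base_url="http://localhost:5001"):
    """Try to list datasets through the Dify service API; return None on failure."""
    api_key = os.getenv("DIFY_DATASET_API_KEY")
    if not api_key:
        print("⚠ DIFY_DATASET_API_KEY not set; skipping authenticated API check")
        return None
    try:
        response = requests.get(
            f"{base_url}/v1/datasets",
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=10,
        )
        response.raise_for_status()
        return response.json()
    except requests.RequestException as e:
        print(f"✗ Authenticated API check failed: {e}")
        return None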


def verify_table_structure():
    """Verify that the table structure meets Dify's requirements."""
    print("\n=== Verifying Table Structure ===")

    expected_columns = {
        "id": "VARCHAR",
        "page_content": "VARCHAR",
        "metadata": "VARCHAR",  # JSON stored as VARCHAR in Clickzetta
        "vector": "ARRAY<FLOAT>",
    }

    expected_metadata_fields = [
        "doc_id",
        "doc_hash",
        "document_id",
        "dataset_id",
    ]

    print("✓ Expected table structure:")
    for col, dtype in expected_columns.items():
        print(f"  - {col}: {dtype}")

    print("\n✓ Required metadata fields:")
    for field in expected_metadata_fields:
        print(f"  - {field}")

    print("\n✓ Index requirements:")
    print("  - Vector index (HNSW) on 'vector' column")
    print("  - Full-text index on 'page_content' (optional)")
    print("  - Functional index on metadata->>'$.doc_id' (recommended)")
    print("  - Functional index on metadata->>'$.document_id' (recommended)")

    return True
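

# Illustrative only: a minimal sketch that compares a live table against the
# expected column set above, reusing the connect()/DESCRIBE calls from
# test_clickzetta_connection. The default table name is an example; adjust it
# to the collection created by your Dify dataset.
def compare_live_table(table_name="collection_test_dataset"):
    """Print any expected columns missing from the live table; return the missing set."""
    expected_columns = {"id", "page_content", "metadata", "vector"}
    conn = connect(
        username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
        password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
        instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
        service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
        workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
        vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
        database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
    )
    with conn.cursor() as cursor:
        cursor.execute(f"DESCRIBE dify.{table_name}")
        live_columns = {row[0] for row in cursor.fetchall()}
    missing = expected_columns - live_columns
    if missing:
        print(f"✗ Missing columns in {table_name}: {sorted(missing)}")
    else:
        print(f"✓ {table_name} contains all expected columns")
    return missing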


def main():
    """Run all tests."""
    print("Starting Clickzetta integration tests for Dify Docker\n")

    tests = [
        ("Direct Clickzetta Connection", test_clickzetta_connection),
        ("Dify API Status", test_dify_api),
        ("Table Structure Verification", verify_table_structure),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            success = test_func()
            results.append((test_name, success))
        except Exception as e:
            print(f"\n✗ {test_name} crashed: {e}")
            results.append((test_name, False))

    # Summary
    print("\n" + "=" * 50)
    print("Test Summary:")
    print("=" * 50)

    passed = sum(1 for _, success in results if success)
    total = len(results)

    for test_name, success in results:
        status = "✅ PASSED" if success else "❌ FAILED"
        print(f"{test_name}: {status}")

    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
        print("\nNext steps:")
        print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
        print("2. Access Dify at http://localhost:3000")
        print("3. Create a dataset and test vector storage with Clickzetta")
        return 0
    else:
        print("\n⚠️ Some tests failed. Please check the errors above.")
        return 1


if __name__ == "__main__":
    exit(main())