@ -3,11 +3,14 @@ import logging
import queue
import threading
import uuid
from typing import Any , Optional
from typing import Any , Optional , TYPE_CHECKING
import clickzetta # type: ignore
from pydantic import BaseModel , model_validator
if TYPE_CHECKING :
from clickzetta import Connection
from configs import dify_config
from core . rag . datasource . vdb . field import Field
from core . rag . datasource . vdb . vector_base import BaseVector
@ -79,7 +82,7 @@ class ClickzettaVector(BaseVector):
super ( ) . __init__ ( collection_name )
self . _config = config
self . _table_name = collection_name . replace ( " - " , " _ " ) . lower ( ) # Ensure valid table name
self . _connection = None
self . _connection : Optional [ " Connection " ] = None
self . _init_connection ( )
self . _init_write_queue ( )
@ -96,6 +99,7 @@ class ClickzettaVector(BaseVector):
)
# Set session parameters for better string handling
if self . _connection is not None :
with self . _connection . cursor ( ) as cursor :
# Use quote mode for string literal escaping to handle quotes better
cursor . execute ( " SET cz.sql.string.literal.escape.mode = ' quote ' " )
@ -117,6 +121,7 @@ class ClickzettaVector(BaseVector):
while not cls . _shutdown :
try :
# Get task from queue with timeout
if cls . _write_queue is not None :
task = cls . _write_queue . get ( timeout = 1 )
if task is None : # Shutdown signal
break
@ -131,6 +136,8 @@ class ClickzettaVector(BaseVector):
result_queue . put ( ( False , e ) )
finally :
cls . _write_queue . task_done ( )
else :
break
except queue . Empty :
continue
except Exception as e :
@ -141,7 +148,7 @@ class ClickzettaVector(BaseVector):
if ClickzettaVector . _write_queue is None :
raise RuntimeError ( " Write queue not initialized " )
result_queue = queue . Queue ( )
result_queue : queue . Queue [ tuple [ bool , Any ] ] = queue . Queue ( )
ClickzettaVector . _write_queue . put ( ( func , args , kwargs , result_queue ) )
# Wait for result
@ -154,10 +161,17 @@ class ClickzettaVector(BaseVector):
""" Return the vector database type. """
return " clickzetta "
def _ensure_connection ( self ) - > " Connection " :
""" Ensure connection is available and return it. """
if self . _connection is None :
raise RuntimeError ( " Database connection not initialized " )
return self . _connection
def _table_exists ( self ) - > bool :
""" Check if the table exists. """
try :
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
cursor . execute ( f " DESC { self . _config . schema_name } . { self . _table_name } " )
return True
except Exception as e :
@ -197,7 +211,8 @@ class ClickzettaVector(BaseVector):
) COMMENT ' Dify RAG knowledge base vector storage table for document embeddings and content '
"""
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
cursor . execute ( create_table_sql )
logger . info ( f " Created table { self . _config . schema_name } . { self . _table_name } " )
@ -369,7 +384,8 @@ class ClickzettaVector(BaseVector):
f " VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR( { vector_dimension } ))) "
)
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
try :
cursor . executemany ( insert_sql , data_rows )
logger . info (
@ -385,7 +401,8 @@ class ClickzettaVector(BaseVector):
def text_exists ( self , id : str ) - > bool :
""" Check if a document exists by ID. """
safe_id = self . _safe_doc_id ( id )
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
cursor . execute (
f " SELECT COUNT(*) FROM { self . _config . schema_name } . { self . _table_name } WHERE id = ? " ,
[ safe_id ]
@ -413,7 +430,8 @@ class ClickzettaVector(BaseVector):
id_list = " , " . join ( f " ' { id } ' " for id in safe_ids )
sql = f " DELETE FROM { self . _config . schema_name } . { self . _table_name } WHERE id IN ( { id_list } ) "
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
cursor . execute ( sql )
def delete_by_metadata_field ( self , key : str , value : str ) - > None :
@ -428,7 +446,8 @@ class ClickzettaVector(BaseVector):
def _delete_by_metadata_field_impl ( self , key : str , value : str ) - > None :
""" Implementation of delete by metadata field (executed in write worker thread). """
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
# Using JSON path to filter with parameterized query
# Note: JSON path requires literal key name, cannot be parameterized
# Use json_extract_string function for ClickZetta compatibility
@ -488,7 +507,8 @@ class ClickzettaVector(BaseVector):
"""
documents = [ ]
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
cursor . execute ( search_sql )
results = cursor . fetchall ( )
@ -572,7 +592,8 @@ class ClickzettaVector(BaseVector):
"""
documents = [ ]
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
try :
cursor . execute ( search_sql )
results = cursor . fetchall ( )
@ -649,7 +670,8 @@ class ClickzettaVector(BaseVector):
"""
documents = [ ]
with self . _connection . cursor ( ) as cursor :
connection = self . _ensure_connection ( )
with connection . cursor ( ) as cursor :
cursor . execute ( search_sql )
results = cursor . fetchall ( )
@ -689,7 +711,8 @@ class ClickzettaVector(BaseVector):
def delete(self) -> None:
    """Delete the entire collection by dropping its backing table.

    Executes ``DROP TABLE IF EXISTS`` for the table addressed by
    ``self._config.schema_name`` and ``self._table_name``.

    Raises:
        RuntimeError: Propagated from ``self._ensure_connection()`` when no
            connection is available.
    """
    # Acquire the connection through the guarded accessor only. The older
    # direct `self._connection.cursor()` access is dropped: it opened a
    # second cursor and failed with AttributeError on a missing connection.
    connection = self._ensure_connection()
    with connection.cursor() as cursor:
        cursor.execute(f" DROP TABLE IF EXISTS {self._config.schema_name} . {self._table_name} ")
@ -718,13 +741,13 @@ class ClickzettaVectorFactory(AbstractVectorFactory):
""" Initialize a Clickzetta vector instance. """
# Get configuration from environment variables or dataset config
config = ClickzettaConfig (
username = dify_config . CLICKZETTA_USERNAME ,
password = dify_config . CLICKZETTA_PASSWORD ,
instance = dify_config . CLICKZETTA_INSTANCE ,
service = dify_config . CLICKZETTA_SERVICE ,
workspace = dify_config . CLICKZETTA_WORKSPACE ,
vcluster = dify_config . CLICKZETTA_VCLUSTER ,
schema_name = dify_config . CLICKZETTA_SCHEMA ,
username = dify_config . CLICKZETTA_USERNAME or " " ,
password = dify_config . CLICKZETTA_PASSWORD or " " ,
instance = dify_config . CLICKZETTA_INSTANCE or " " ,
service = dify_config . CLICKZETTA_SERVICE or " api.clickzetta.com " ,
workspace = dify_config . CLICKZETTA_WORKSPACE or " quick_start " ,
vcluster = dify_config . CLICKZETTA_VCLUSTER or " default_ap " ,
schema_name = dify_config . CLICKZETTA_SCHEMA or " dify " ,
batch_size = dify_config . CLICKZETTA_BATCH_SIZE or 100 ,
enable_inverted_index = dify_config . CLICKZETTA_ENABLE_INVERTED_INDEX or True ,
analyzer_type = dify_config . CLICKZETTA_ANALYZER_TYPE or " chinese " ,