@ -317,11 +317,12 @@ class ClickzettaVector(BaseVector):
# Prepare batch insert
# Prepare batch insert
values = [ ]
values = [ ]
for doc , embedding in zip ( batch_docs , batch_embeddings ) :
for doc , embedding in zip ( batch_docs , batch_embeddings ) :
doc_id = doc . metadata . get ( " doc_id " , str ( uuid . uuid4 ( ) ) )
doc_id = self . _safe_doc_id ( doc . metadata . get ( " doc_id " , str ( uuid . uuid4 ( ) ) ) )
# For JSON column in Clickzetta, use JSON 'json_string' format
# For JSON column in Clickzetta, use safe JSON formatting
metadata_json = json . dumps ( doc . metadata ) . replace ( " ' " , " ' ' " ) # Escape single quotes
metadata_json = self . _escape_json_string ( doc . metadata )
embedding_str = self . _format_vector ( embedding )
embedding_str = self . _format_vector ( embedding )
values . append ( f " ( ' { doc_id } ' , ' { self . _escape_string ( doc . page_content ) } ' , "
escaped_content = self . _escape_string ( doc . page_content )
values . append ( f " ( ' { doc_id } ' , ' { escaped_content } ' , "
f " JSON ' { metadata_json } ' , { embedding_str } ) " )
f " JSON ' { metadata_json } ' , { embedding_str } ) " )
# Use regular INSERT - primary key will handle duplicates
# Use regular INSERT - primary key will handle duplicates
@ -337,9 +338,10 @@ class ClickzettaVector(BaseVector):
def text_exists ( self , id : str ) - > bool :
def text_exists ( self , id : str ) - > bool :
""" Check if a document exists by ID. """
""" Check if a document exists by ID. """
safe_id = self . _safe_doc_id ( id )
with self . _connection . cursor ( ) as cursor :
with self . _connection . cursor ( ) as cursor :
cursor . execute (
cursor . execute (
f " SELECT COUNT(*) FROM { self . _config . schema } . { self . _table_name } WHERE id = ' { id} ' "
f " SELECT COUNT(*) FROM { self . _config . schema } . { self . _table_name } WHERE id = ' { safe_ id} ' "
)
)
result = cursor . fetchone ( )
result = cursor . fetchone ( )
return result [ 0 ] > 0 if result else False
return result [ 0 ] > 0 if result else False
@ -359,7 +361,8 @@ class ClickzettaVector(BaseVector):
def _delete_by_ids_impl ( self , ids : list [ str ] ) - > None :
def _delete_by_ids_impl ( self , ids : list [ str ] ) - > None :
""" Implementation of delete by IDs (executed in write worker thread). """
""" Implementation of delete by IDs (executed in write worker thread). """
ids_str = " , " . join ( f " ' { id } ' " for id in ids )
safe_ids = [ self . _safe_doc_id ( id ) for id in ids ]
ids_str = " , " . join ( f " ' { id } ' " for id in safe_ids )
with self . _connection . cursor ( ) as cursor :
with self . _connection . cursor ( ) as cursor :
cursor . execute (
cursor . execute (
f " DELETE FROM { self . _config . schema } . { self . _table_name } WHERE id IN ( { ids_str } ) "
f " DELETE FROM { self . _config . schema } . { self . _table_name } WHERE id IN ( { ids_str } ) "
@ -377,11 +380,14 @@ class ClickzettaVector(BaseVector):
def _delete_by_metadata_field_impl ( self , key : str , value : str ) - > None :
def _delete_by_metadata_field_impl ( self , key : str , value : str ) - > None :
""" Implementation of delete by metadata field (executed in write worker thread). """
""" Implementation of delete by metadata field (executed in write worker thread). """
# Safely escape the key and value
safe_key = self . _escape_string ( key )
safe_value = self . _escape_string ( value )
with self . _connection . cursor ( ) as cursor :
with self . _connection . cursor ( ) as cursor :
# Using JSON path to filter
# Using JSON path to filter
cursor . execute (
cursor . execute (
f " DELETE FROM { self . _config . schema } . { self . _table_name } "
f " DELETE FROM { self . _config . schema } . { self . _table_name } "
f " WHERE { Field . METADATA_KEY . value } ->> ' $. { key} ' = ' { value} ' "
f " WHERE { Field . METADATA_KEY . value } ->> ' $. { safe_ key} ' = ' { safe_ value} ' "
)
)
def search_by_vector ( self , query_vector : list [ float ] , * * kwargs : Any ) - > list [ Document ] :
def search_by_vector ( self , query_vector : list [ float ] , * * kwargs : Any ) - > list [ Document ] :
@ -393,7 +399,8 @@ class ClickzettaVector(BaseVector):
# Build filter clause
# Build filter clause
filter_clauses = [ ]
filter_clauses = [ ]
if document_ids_filter :
if document_ids_filter :
doc_ids_str = " , " . join ( f " ' { id } ' " for id in document_ids_filter )
safe_doc_ids = [ self . _escape_string ( str ( id ) ) for id in document_ids_filter ]
doc_ids_str = " , " . join ( f " ' { id } ' " for id in safe_doc_ids )
filter_clauses . append ( f " { Field . METADATA_KEY . value } ->> ' $.document_id ' IN ( { doc_ids_str } ) " )
filter_clauses . append ( f " { Field . METADATA_KEY . value } ->> ' $.document_id ' IN ( { doc_ids_str } ) " )
# Add distance threshold based on distance function
# Add distance threshold based on distance function
@ -457,7 +464,8 @@ class ClickzettaVector(BaseVector):
# Build filter clause
# Build filter clause
filter_clauses = [ ]
filter_clauses = [ ]
if document_ids_filter :
if document_ids_filter :
doc_ids_str = " , " . join ( f " ' { id } ' " for id in document_ids_filter )
safe_doc_ids = [ self . _escape_string ( str ( id ) ) for id in document_ids_filter ]
doc_ids_str = " , " . join ( f " ' { id } ' " for id in safe_doc_ids )
filter_clauses . append ( f " { Field . METADATA_KEY . value } ->> ' $.document_id ' IN ( { doc_ids_str } ) " )
filter_clauses . append ( f " { Field . METADATA_KEY . value } ->> ' $.document_id ' IN ( { doc_ids_str } ) " )
# Use match_all function for full-text search
# Use match_all function for full-text search
@ -501,7 +509,8 @@ class ClickzettaVector(BaseVector):
# Build filter clause
# Build filter clause
filter_clauses = [ ]
filter_clauses = [ ]
if document_ids_filter :
if document_ids_filter :
doc_ids_str = " , " . join ( f " ' { id } ' " for id in document_ids_filter )
safe_doc_ids = [ self . _escape_string ( str ( id ) ) for id in document_ids_filter ]
doc_ids_str = " , " . join ( f " ' { id } ' " for id in safe_doc_ids )
filter_clauses . append ( f " { Field . METADATA_KEY . value } ->> ' $.document_id ' IN ( { doc_ids_str } ) " )
filter_clauses . append ( f " { Field . METADATA_KEY . value } ->> ' $.document_id ' IN ( { doc_ids_str } ) " )
filter_clauses . append ( f " { Field . CONTENT_KEY . value } LIKE ' % { self . _escape_string ( query ) } % ' " )
filter_clauses . append ( f " { Field . CONTENT_KEY . value } LIKE ' % { self . _escape_string ( query ) } % ' " )
@ -533,8 +542,17 @@ class ClickzettaVector(BaseVector):
cursor . execute ( f " DROP TABLE IF EXISTS { self . _config . schema } . { self . _table_name } " )
cursor . execute ( f " DROP TABLE IF EXISTS { self . _config . schema } . { self . _table_name } " )
def _escape_string ( self , s : str ) - > str :
def _escape_string ( self , s : str ) - > str :
""" Escape single quotes in strings for SQL. """
""" Escape single quotes and other special characters for SQL. """
return s . replace ( " ' " , " ' ' " )
if s is None :
return " "
# Replace single quotes and other potentially problematic characters
s = str ( s )
s = s . replace ( " ' " , " ' ' " ) # Escape single quotes
s = s . replace ( " \\ " , " \\ \\ " ) # Escape backslashes
s = s . replace ( " \n " , " \\ n " ) # Escape newlines
s = s . replace ( " \r " , " \\ r " ) # Escape carriage returns
s = s . replace ( " \t " , " \\ t " ) # Escape tabs
return s
def _format_vector ( self , vector : list [ float ] ) - > str :
def _format_vector ( self , vector : list [ float ] ) - > str :
""" Safely format vector for SQL, handling special float values. """
""" Safely format vector for SQL, handling special float values. """
@ -554,6 +572,28 @@ class ClickzettaVector(BaseVector):
else :
else :
safe_values . append ( " 0.0 " )
safe_values . append ( " 0.0 " )
return f " VECTOR( { ' , ' . join ( safe_values ) } ) "
return f " VECTOR( { ' , ' . join ( safe_values ) } ) "
def _escape_json_string ( self , obj : dict ) - > str :
""" Safely format JSON for SQL, escaping special characters. """
try :
json_str = json . dumps ( obj , ensure_ascii = True )
# Escape single quotes for SQL
return json_str . replace ( " ' " , " ' ' " )
except ( TypeError , ValueError ) as e :
logger . warning ( f " Failed to serialize metadata to JSON: { e } " )
return " {} "
def _safe_doc_id ( self , doc_id : str ) - > str :
""" Ensure doc_id is safe for SQL and doesn ' t contain special characters. """
if not doc_id :
return str ( uuid . uuid4 ( ) )
# Remove or replace potentially problematic characters
safe_id = str ( doc_id )
# Only allow alphanumeric, hyphens, underscores
safe_id = ' ' . join ( c for c in safe_id if c . isalnum ( ) or c in ' -_ ' )
if not safe_id : # If all characters were removed
return str ( uuid . uuid4 ( ) )
return safe_id [ : 255 ] # Limit length
class ClickzettaVectorFactory ( AbstractVectorFactory ) :
class ClickzettaVectorFactory ( AbstractVectorFactory ) :