Fix SQL injection vulnerabilities and character encoding issues

- Enhanced string escaping for SQL safety (backslashes, newlines, carriage returns, tabs)
- Added safe JSON formatting with ensure_ascii=True
- Implemented doc_id sanitization (keeps only alphanumerics, hyphens, and underscores; truncates to 255 characters; falls back to a generated UUID when nothing remains)
- Protected all user input: document content, metadata, IDs, search queries
- Fixed potential SQL syntax errors from special characters in document content

This addresses "Syntax error at or near 'files'" errors that occur when
document content or metadata contains special characters that break SQL syntax.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22551/head
yunqiqiliang 10 months ago
parent 9c2bf2b30f
commit f116407045

@ -317,11 +317,12 @@ class ClickzettaVector(BaseVector):
# Prepare batch insert # Prepare batch insert
values = [] values = []
for doc, embedding in zip(batch_docs, batch_embeddings): for doc, embedding in zip(batch_docs, batch_embeddings):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) doc_id = self._safe_doc_id(doc.metadata.get("doc_id", str(uuid.uuid4())))
# For JSON column in Clickzetta, use JSON 'json_string' format # For JSON column in Clickzetta, use safe JSON formatting
metadata_json = json.dumps(doc.metadata).replace("'", "''") # Escape single quotes metadata_json = self._escape_json_string(doc.metadata)
embedding_str = self._format_vector(embedding) embedding_str = self._format_vector(embedding)
values.append(f"('{doc_id}', '{self._escape_string(doc.page_content)}', " escaped_content = self._escape_string(doc.page_content)
values.append(f"('{doc_id}', '{escaped_content}', "
f"JSON '{metadata_json}', {embedding_str})") f"JSON '{metadata_json}', {embedding_str})")
# Use regular INSERT - primary key will handle duplicates # Use regular INSERT - primary key will handle duplicates
@ -337,9 +338,10 @@ class ClickzettaVector(BaseVector):
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
"""Check if a document exists by ID.""" """Check if a document exists by ID."""
safe_id = self._safe_doc_id(id)
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
cursor.execute( cursor.execute(
f"SELECT COUNT(*) FROM {self._config.schema}.{self._table_name} WHERE id = '{id}'" f"SELECT COUNT(*) FROM {self._config.schema}.{self._table_name} WHERE id = '{safe_id}'"
) )
result = cursor.fetchone() result = cursor.fetchone()
return result[0] > 0 if result else False return result[0] > 0 if result else False
@ -359,7 +361,8 @@ class ClickzettaVector(BaseVector):
def _delete_by_ids_impl(self, ids: list[str]) -> None: def _delete_by_ids_impl(self, ids: list[str]) -> None:
"""Implementation of delete by IDs (executed in write worker thread).""" """Implementation of delete by IDs (executed in write worker thread)."""
ids_str = ",".join(f"'{id}'" for id in ids) safe_ids = [self._safe_doc_id(id) for id in ids]
ids_str = ",".join(f"'{id}'" for id in safe_ids)
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
cursor.execute( cursor.execute(
f"DELETE FROM {self._config.schema}.{self._table_name} WHERE id IN ({ids_str})" f"DELETE FROM {self._config.schema}.{self._table_name} WHERE id IN ({ids_str})"
@ -377,11 +380,14 @@ class ClickzettaVector(BaseVector):
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
"""Implementation of delete by metadata field (executed in write worker thread).""" """Implementation of delete by metadata field (executed in write worker thread)."""
# Safely escape the key and value
safe_key = self._escape_string(key)
safe_value = self._escape_string(value)
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
# Using JSON path to filter # Using JSON path to filter
cursor.execute( cursor.execute(
f"DELETE FROM {self._config.schema}.{self._table_name} " f"DELETE FROM {self._config.schema}.{self._table_name} "
f"WHERE {Field.METADATA_KEY.value}->>'$.{key}' = '{value}'" f"WHERE {Field.METADATA_KEY.value}->>'$.{safe_key}' = '{safe_value}'"
) )
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
@ -393,7 +399,8 @@ class ClickzettaVector(BaseVector):
# Build filter clause # Build filter clause
filter_clauses = [] filter_clauses = []
if document_ids_filter: if document_ids_filter:
doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})")
# Add distance threshold based on distance function # Add distance threshold based on distance function
@ -457,7 +464,8 @@ class ClickzettaVector(BaseVector):
# Build filter clause # Build filter clause
filter_clauses = [] filter_clauses = []
if document_ids_filter: if document_ids_filter:
doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})")
# Use match_all function for full-text search # Use match_all function for full-text search
@ -501,7 +509,8 @@ class ClickzettaVector(BaseVector):
# Build filter clause # Build filter clause
filter_clauses = [] filter_clauses = []
if document_ids_filter: if document_ids_filter:
doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})")
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{self._escape_string(query)}%'") filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{self._escape_string(query)}%'")
@ -533,8 +542,17 @@ class ClickzettaVector(BaseVector):
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema}.{self._table_name}") cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema}.{self._table_name}")
def _escape_string(self, s: str) -> str: def _escape_string(self, s: str) -> str:
"""Escape single quotes in strings for SQL.""" """Escape single quotes and other special characters for SQL."""
return s.replace("'", "''") if s is None:
return ""
# Replace single quotes and other potentially problematic characters
s = str(s)
s = s.replace("'", "''") # Escape single quotes
s = s.replace("\\", "\\\\") # Escape backslashes
s = s.replace("\n", "\\n") # Escape newlines
s = s.replace("\r", "\\r") # Escape carriage returns
s = s.replace("\t", "\\t") # Escape tabs
return s
def _format_vector(self, vector: list[float]) -> str: def _format_vector(self, vector: list[float]) -> str:
"""Safely format vector for SQL, handling special float values.""" """Safely format vector for SQL, handling special float values."""
@ -555,6 +573,28 @@ class ClickzettaVector(BaseVector):
safe_values.append("0.0") safe_values.append("0.0")
return f"VECTOR({','.join(safe_values)})" return f"VECTOR({','.join(safe_values)})"
def _escape_json_string(self, obj: dict) -> str:
"""Safely format JSON for SQL, escaping special characters."""
try:
json_str = json.dumps(obj, ensure_ascii=True)
# Escape single quotes for SQL
return json_str.replace("'", "''")
except (TypeError, ValueError) as e:
logger.warning(f"Failed to serialize metadata to JSON: {e}")
return "{}"
def _safe_doc_id(self, doc_id: str) -> str:
"""Ensure doc_id is safe for SQL and doesn't contain special characters."""
if not doc_id:
return str(uuid.uuid4())
# Remove or replace potentially problematic characters
safe_id = str(doc_id)
# Only allow alphanumeric, hyphens, underscores
safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_')
if not safe_id: # If all characters were removed
return str(uuid.uuid4())
return safe_id[:255] # Limit length
class ClickzettaVectorFactory(AbstractVectorFactory): class ClickzettaVectorFactory(AbstractVectorFactory):
"""Factory for creating Clickzetta vector instances.""" """Factory for creating Clickzetta vector instances."""

Loading…
Cancel
Save