From f1164070453e6c6a8d21db899363afe3dc707dec Mon Sep 17 00:00:00 2001 From: yunqiqiliang <132561395+yunqiqiliang@users.noreply.github.com> Date: Thu, 17 Jul 2025 17:37:49 +0800 Subject: [PATCH] Fix SQL injection vulnerabilities and character encoding issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Enhanced string escaping for SQL safety (backslashes, newlines, tabs) - Added safe JSON formatting with ensure_ascii=True - Implemented safe doc_id validation (alphanumeric + hyphens/underscores only) - Protected all user input: document content, metadata, IDs, search queries - Fixed potential SQL syntax errors from special characters in document content This addresses "Syntax error at or near 'files'" errors that occur when document content or metadata contains special characters that break SQL syntax. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../vdb/clickzetta/clickzetta_vector.py | 64 +++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 3d786ff5f5..b484f0cb6b 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -317,11 +317,12 @@ class ClickzettaVector(BaseVector): # Prepare batch insert values = [] for doc, embedding in zip(batch_docs, batch_embeddings): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - # For JSON column in Clickzetta, use JSON 'json_string' format - metadata_json = json.dumps(doc.metadata).replace("'", "''") # Escape single quotes + doc_id = self._safe_doc_id(doc.metadata.get("doc_id", str(uuid.uuid4()))) + # For JSON column in Clickzetta, use safe JSON formatting + metadata_json = self._escape_json_string(doc.metadata) embedding_str = self._format_vector(embedding) - values.append(f"('{doc_id}', '{self._escape_string(doc.page_content)}', " + escaped_content = self._escape_string(doc.page_content) + values.append(f"('{doc_id}', '{escaped_content}', " f"JSON '{metadata_json}', {embedding_str})") # Use regular INSERT - primary key will handle duplicates @@ -337,9 +338,10 @@ class ClickzettaVector(BaseVector): def text_exists(self, id: str) -> bool: """Check if a document exists by ID.""" + safe_id = self._safe_doc_id(id) with self._connection.cursor() as cursor: cursor.execute( - f"SELECT COUNT(*) FROM {self._config.schema}.{self._table_name} WHERE id = '{id}'" + f"SELECT COUNT(*) FROM {self._config.schema}.{self._table_name} WHERE id = '{safe_id}'" ) result = cursor.fetchone() return result[0] > 0 if result else False @@ -359,7 +361,8 @@ class ClickzettaVector(BaseVector): def _delete_by_ids_impl(self, ids: list[str]) -> None: """Implementation of delete by IDs (executed in write worker thread).""" - ids_str = ",".join(f"'{id}'" for id in ids) + safe_ids = [self._safe_doc_id(id) for id in ids] + ids_str = ",".join(f"'{id}'" for id in safe_ids) with self._connection.cursor() as cursor: cursor.execute( f"DELETE FROM {self._config.schema}.{self._table_name} WHERE id IN ({ids_str})" @@ -377,11 +380,14 @@ class ClickzettaVector(BaseVector): def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: """Implementation of delete by metadata field (executed in write worker thread).""" + # Safely escape the key and value + safe_key = self._escape_string(key) + safe_value = self._escape_string(value) with self._connection.cursor() as cursor: # Using JSON path to filter cursor.execute( f"DELETE FROM {self._config.schema}.{self._table_name} " - f"WHERE {Field.METADATA_KEY.value}->>'$.{key}' = '{value}'" + f"WHERE {Field.METADATA_KEY.value}->>'$.{safe_key}' = '{safe_value}'" ) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: @@ -393,7 +399,8 @@ class ClickzettaVector(BaseVector): # Build filter clause filter_clauses = [] if document_ids_filter: - doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) + safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") # Add distance threshold based on distance function @@ -457,7 +464,8 @@ class ClickzettaVector(BaseVector): # Build filter clause filter_clauses = [] if document_ids_filter: - doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) + safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") # Use match_all function for full-text search @@ -501,7 +509,8 @@ class ClickzettaVector(BaseVector): # Build filter clause filter_clauses = [] if document_ids_filter: - doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) + safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{self._escape_string(query)}%'") @@ -533,8 +542,17 @@ class ClickzettaVector(BaseVector): cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema}.{self._table_name}") def _escape_string(self, s: str) -> str: - """Escape single quotes in strings for SQL.""" - return s.replace("'", "''") + """Escape single quotes and other special characters for SQL.""" + if s is None: + return "" + # Replace single quotes and other potentially problematic characters + s = str(s) + s = s.replace("'", "''") # Escape single quotes + s = s.replace("\\", "\\\\") # Escape backslashes + s = s.replace("\n", "\\n") # Escape newlines + s = s.replace("\r", "\\r") # Escape carriage returns + s = s.replace("\t", "\\t") # Escape tabs + return s def _format_vector(self, vector: list[float]) -> str: """Safely format vector for SQL, handling special float values.""" @@ -554,6 +572,28 @@ class ClickzettaVector(BaseVector): else: safe_values.append("0.0") return f"VECTOR({','.join(safe_values)})" + + def _escape_json_string(self, obj: dict) -> str: + """Safely format JSON for SQL, escaping special characters.""" + try: + json_str = json.dumps(obj, ensure_ascii=True) + # Escape single quotes for SQL + return json_str.replace("'", "''") + except (TypeError, ValueError) as e: + logger.warning(f"Failed to serialize metadata to JSON: {e}") + return "{}" + + def _safe_doc_id(self, doc_id: str) -> str: + """Ensure doc_id is safe for SQL and doesn't contain special characters.""" + if not doc_id: + return str(uuid.uuid4()) + # Remove or replace potentially problematic characters + safe_id = str(doc_id) + # Only allow alphanumeric, hyphens, underscores + safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_') + if not safe_id: # If all characters were removed + return str(uuid.uuid4()) + return safe_id[:255] # Limit length class ClickzettaVectorFactory(AbstractVectorFactory):