diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 3d786ff5f5..b484f0cb6b 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -317,11 +317,12 @@ class ClickzettaVector(BaseVector): # Prepare batch insert values = [] for doc, embedding in zip(batch_docs, batch_embeddings): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - # For JSON column in Clickzetta, use JSON 'json_string' format - metadata_json = json.dumps(doc.metadata).replace("'", "''") # Escape single quotes + doc_id = self._safe_doc_id(doc.metadata.get("doc_id", str(uuid.uuid4()))) + # For JSON column in Clickzetta, use safe JSON formatting + metadata_json = self._escape_json_string(doc.metadata) embedding_str = self._format_vector(embedding) - values.append(f"('{doc_id}', '{self._escape_string(doc.page_content)}', " + escaped_content = self._escape_string(doc.page_content) + values.append(f"('{doc_id}', '{escaped_content}', " f"JSON '{metadata_json}', {embedding_str})") # Use regular INSERT - primary key will handle duplicates @@ -337,9 +338,10 @@ class ClickzettaVector(BaseVector): def text_exists(self, id: str) -> bool: """Check if a document exists by ID.""" + safe_id = self._safe_doc_id(id) with self._connection.cursor() as cursor: cursor.execute( - f"SELECT COUNT(*) FROM {self._config.schema}.{self._table_name} WHERE id = '{id}'" + f"SELECT COUNT(*) FROM {self._config.schema}.{self._table_name} WHERE id = '{safe_id}'" ) result = cursor.fetchone() return result[0] > 0 if result else False @@ -359,7 +361,8 @@ class ClickzettaVector(BaseVector): def _delete_by_ids_impl(self, ids: list[str]) -> None: """Implementation of delete by IDs (executed in write worker thread).""" - ids_str = ",".join(f"'{id}'" for id in ids) + safe_ids = [self._safe_doc_id(id) for id in ids] + ids_str = ",".join(f"'{id}'" for id in safe_ids) with self._connection.cursor() as cursor: cursor.execute( f"DELETE FROM {self._config.schema}.{self._table_name} WHERE id IN ({ids_str})" @@ -377,11 +380,14 @@ class ClickzettaVector(BaseVector): def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: """Implementation of delete by metadata field (executed in write worker thread).""" + # Safely escape the key and value + safe_key = self._escape_string(key) + safe_value = self._escape_string(value) with self._connection.cursor() as cursor: # Using JSON path to filter cursor.execute( f"DELETE FROM {self._config.schema}.{self._table_name} " - f"WHERE {Field.METADATA_KEY.value}->>'$.{key}' = '{value}'" + f"WHERE {Field.METADATA_KEY.value}->>'$.{safe_key}' = '{safe_value}'" ) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: @@ -393,7 +399,8 @@ class ClickzettaVector(BaseVector): # Build filter clause filter_clauses = [] if document_ids_filter: - doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) + safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") # Add distance threshold based on distance function @@ -457,7 +464,8 @@ class ClickzettaVector(BaseVector): # Build filter clause filter_clauses = [] if document_ids_filter: - doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) + safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") # Use match_all function for full-text search @@ -501,7 +509,8 @@ class ClickzettaVector(BaseVector): # Build filter clause filter_clauses = [] if document_ids_filter: - doc_ids_str = ",".join(f"'{id}'" for id in document_ids_filter) + safe_doc_ids = [self._escape_string(str(id)) for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) filter_clauses.append(f"{Field.METADATA_KEY.value}->>'$.document_id' IN ({doc_ids_str})") filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{self._escape_string(query)}%'") @@ -533,8 +542,17 @@ class ClickzettaVector(BaseVector): cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema}.{self._table_name}") def _escape_string(self, s: str) -> str: - """Escape single quotes in strings for SQL.""" - return s.replace("'", "''") + """Escape single quotes and other special characters for SQL.""" + if s is None: + return "" + # Replace single quotes and other potentially problematic characters + s = str(s) + s = s.replace("'", "''") # Escape single quotes + s = s.replace("\\", "\\\\") # Escape backslashes + s = s.replace("\n", "\\n") # Escape newlines + s = s.replace("\r", "\\r") # Escape carriage returns + s = s.replace("\t", "\\t") # Escape tabs + return s def _format_vector(self, vector: list[float]) -> str: """Safely format vector for SQL, handling special float values.""" @@ -554,6 +572,28 @@ class ClickzettaVector(BaseVector): else: safe_values.append("0.0") return f"VECTOR({','.join(safe_values)})" + + def _escape_json_string(self, obj: dict) -> str: + """Safely format JSON for SQL, escaping special characters.""" + try: + json_str = json.dumps(obj, ensure_ascii=True) + # Escape single quotes for SQL + return json_str.replace("'", "''") + except (TypeError, ValueError) as e: + logger.warning(f"Failed to serialize metadata to JSON: {e}") + return "{}" + + def _safe_doc_id(self, doc_id: str) -> str: + """Ensure doc_id is safe for SQL and doesn't contain special characters.""" + if not doc_id: + return str(uuid.uuid4()) + # Remove or replace potentially problematic characters + safe_id = str(doc_id) + # Only allow alphanumeric, hyphens, underscores + safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_') + if not safe_id: # If all characters were removed + return str(uuid.uuid4()) + return safe_id[:255] # Limit length class ClickzettaVectorFactory(AbstractVectorFactory):