diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 181fe56f98..a3459117a8 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -3,11 +3,14 @@ import logging import queue import threading import uuid -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING import clickzetta # type: ignore from pydantic import BaseModel, model_validator +if TYPE_CHECKING: + from clickzetta import Connection + from configs import dify_config from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -79,7 +82,7 @@ class ClickzettaVector(BaseVector): super().__init__(collection_name) self._config = config self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name - self._connection = None + self._connection: Optional["Connection"] = None self._init_connection() self._init_write_queue() @@ -96,10 +99,11 @@ class ClickzettaVector(BaseVector): ) # Set session parameters for better string handling - with self._connection.cursor() as cursor: - # Use quote mode for string literal escaping to handle quotes better - cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") - logger.info("Set string literal escape mode to 'quote' for better quote handling") + if self._connection is not None: + with self._connection.cursor() as cursor: + # Use quote mode for string literal escaping to handle quotes better + cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") + logger.info("Set string literal escape mode to 'quote' for better quote handling") @classmethod def _init_write_queue(cls): @@ -117,20 +121,23 @@ class ClickzettaVector(BaseVector): while not cls._shutdown: try: # Get task from queue with timeout - task = cls._write_queue.get(timeout=1) - if task is None: # Shutdown signal - break + if cls._write_queue is not None: + task = cls._write_queue.get(timeout=1) + if task is None: # Shutdown signal + break - # Execute the write task - func, args, kwargs, result_queue = task - try: - result = func(*args, **kwargs) - result_queue.put((True, result)) - except Exception as e: - logger.exception("Write task failed") - result_queue.put((False, e)) - finally: - cls._write_queue.task_done() + # Execute the write task + func, args, kwargs, result_queue = task + try: + result = func(*args, **kwargs) + result_queue.put((True, result)) + except Exception as e: + logger.exception("Write task failed") + result_queue.put((False, e)) + finally: + cls._write_queue.task_done() + else: + break except queue.Empty: continue except Exception as e: @@ -141,7 +148,7 @@ class ClickzettaVector(BaseVector): if ClickzettaVector._write_queue is None: raise RuntimeError("Write queue not initialized") - result_queue = queue.Queue() + result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue() ClickzettaVector._write_queue.put((func, args, kwargs, result_queue)) # Wait for result @@ -154,10 +161,17 @@ class ClickzettaVector(BaseVector): """Return the vector database type.""" return "clickzetta" + def _ensure_connection(self) -> "Connection": + """Ensure connection is available and return it.""" + if self._connection is None: + raise RuntimeError("Database connection not initialized") + return self._connection + def _table_exists(self) -> bool: """Check if the table exists.""" try: - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") return True except Exception as e: @@ -197,7 +211,8 @@ class ClickzettaVector(BaseVector): ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' """ - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: cursor.execute(create_table_sql) logger.info(f"Created table {self._config.schema_name}.{self._table_name}") @@ -369,7 +384,8 @@ class ClickzettaVector(BaseVector): f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" ) - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: try: cursor.executemany(insert_sql, data_rows) logger.info( @@ -385,7 +401,8 @@ class ClickzettaVector(BaseVector): def text_exists(self, id: str) -> bool: """Check if a document exists by ID.""" safe_id = self._safe_doc_id(id) - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: cursor.execute( f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", [safe_id] @@ -413,7 +430,8 @@ class ClickzettaVector(BaseVector): id_list = ",".join(f"'{id}'" for id in safe_ids) sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})" - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: cursor.execute(sql) def delete_by_metadata_field(self, key: str, value: str) -> None: @@ -428,7 +446,8 @@ class ClickzettaVector(BaseVector): def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: """Implementation of delete by metadata field (executed in write worker thread).""" - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: # Using JSON path to filter with parameterized query # Note: JSON path requires literal key name, cannot be parameterized # Use json_extract_string function for ClickZetta compatibility @@ -488,7 +507,8 @@ class ClickzettaVector(BaseVector): """ documents = [] - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: cursor.execute(search_sql) results = cursor.fetchall() @@ -572,7 +592,8 @@ class ClickzettaVector(BaseVector): """ documents = [] - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: try: cursor.execute(search_sql) results = cursor.fetchall() @@ -649,7 +670,8 @@ class ClickzettaVector(BaseVector): """ documents = [] - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: cursor.execute(search_sql) results = cursor.fetchall() @@ -689,7 +711,8 @@ class ClickzettaVector(BaseVector): def delete(self) -> None: """Delete the entire collection.""" - with self._connection.cursor() as cursor: + connection = self._ensure_connection() + with connection.cursor() as cursor: cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") @@ -718,13 +741,13 @@ class ClickzettaVectorFactory(AbstractVectorFactory): """Initialize a Clickzetta vector instance.""" # Get configuration from environment variables or dataset config config = ClickzettaConfig( - username=dify_config.CLICKZETTA_USERNAME, - password=dify_config.CLICKZETTA_PASSWORD, - instance=dify_config.CLICKZETTA_INSTANCE, - service=dify_config.CLICKZETTA_SERVICE, - workspace=dify_config.CLICKZETTA_WORKSPACE, - vcluster=dify_config.CLICKZETTA_VCLUSTER, - schema_name=dify_config.CLICKZETTA_SCHEMA, + username=dify_config.CLICKZETTA_USERNAME or "", + password=dify_config.CLICKZETTA_PASSWORD or "", + instance=dify_config.CLICKZETTA_INSTANCE or "", + service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com", + workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start", + vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap", + schema_name=dify_config.CLICKZETTA_SCHEMA or "dify", batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100, enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True, analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",