Fix mypy type-checking errors in ClickZetta vector implementation

- Add proper type annotations for Connection from clickzetta module
- Implement _ensure_connection() method to handle None connection checks
- Fix all database cursor access patterns to use proper null checking
- Add type annotation for result queue in _execute_write method
- Resolve factory method configuration issues by falling back to default values when settings are None

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22551/head
yunqiqiliang 10 months ago
parent f57fa13f1b
commit c3851595d0

@ -3,11 +3,14 @@ import logging
import queue import queue
import threading import threading
import uuid import uuid
from typing import Any, Optional from typing import Any, Optional, TYPE_CHECKING
import clickzetta # type: ignore import clickzetta # type: ignore
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from clickzetta import Connection
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
@ -79,7 +82,7 @@ class ClickzettaVector(BaseVector):
super().__init__(collection_name) super().__init__(collection_name)
self._config = config self._config = config
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name
self._connection = None self._connection: Optional["Connection"] = None
self._init_connection() self._init_connection()
self._init_write_queue() self._init_write_queue()
@ -96,10 +99,11 @@ class ClickzettaVector(BaseVector):
) )
# Set session parameters for better string handling # Set session parameters for better string handling
with self._connection.cursor() as cursor: if self._connection is not None:
# Use quote mode for string literal escaping to handle quotes better with self._connection.cursor() as cursor:
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") # Use quote mode for string literal escaping to handle quotes better
logger.info("Set string literal escape mode to 'quote' for better quote handling") cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
logger.info("Set string literal escape mode to 'quote' for better quote handling")
@classmethod @classmethod
def _init_write_queue(cls): def _init_write_queue(cls):
@ -117,20 +121,23 @@ class ClickzettaVector(BaseVector):
while not cls._shutdown: while not cls._shutdown:
try: try:
# Get task from queue with timeout # Get task from queue with timeout
task = cls._write_queue.get(timeout=1) if cls._write_queue is not None:
if task is None: # Shutdown signal task = cls._write_queue.get(timeout=1)
break if task is None: # Shutdown signal
break
# Execute the write task # Execute the write task
func, args, kwargs, result_queue = task func, args, kwargs, result_queue = task
try: try:
result = func(*args, **kwargs) result = func(*args, **kwargs)
result_queue.put((True, result)) result_queue.put((True, result))
except Exception as e: except Exception as e:
logger.exception("Write task failed") logger.exception("Write task failed")
result_queue.put((False, e)) result_queue.put((False, e))
finally: finally:
cls._write_queue.task_done() cls._write_queue.task_done()
else:
break
except queue.Empty: except queue.Empty:
continue continue
except Exception as e: except Exception as e:
@ -141,7 +148,7 @@ class ClickzettaVector(BaseVector):
if ClickzettaVector._write_queue is None: if ClickzettaVector._write_queue is None:
raise RuntimeError("Write queue not initialized") raise RuntimeError("Write queue not initialized")
result_queue = queue.Queue() result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue()
ClickzettaVector._write_queue.put((func, args, kwargs, result_queue)) ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))
# Wait for result # Wait for result
@ -154,10 +161,17 @@ class ClickzettaVector(BaseVector):
"""Return the vector database type.""" """Return the vector database type."""
return "clickzetta" return "clickzetta"
def _ensure_connection(self) -> "Connection":
"""Ensure connection is available and return it."""
if self._connection is None:
raise RuntimeError("Database connection not initialized")
return self._connection
def _table_exists(self) -> bool: def _table_exists(self) -> bool:
"""Check if the table exists.""" """Check if the table exists."""
try: try:
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
return True return True
except Exception as e: except Exception as e:
@ -197,7 +211,8 @@ class ClickzettaVector(BaseVector):
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
""" """
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(create_table_sql) cursor.execute(create_table_sql)
logger.info(f"Created table {self._config.schema_name}.{self._table_name}") logger.info(f"Created table {self._config.schema_name}.{self._table_name}")
@ -369,7 +384,8 @@ class ClickzettaVector(BaseVector):
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
) )
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
try: try:
cursor.executemany(insert_sql, data_rows) cursor.executemany(insert_sql, data_rows)
logger.info( logger.info(
@ -385,7 +401,8 @@ class ClickzettaVector(BaseVector):
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
"""Check if a document exists by ID.""" """Check if a document exists by ID."""
safe_id = self._safe_doc_id(id) safe_id = self._safe_doc_id(id)
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute( cursor.execute(
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
[safe_id] [safe_id]
@ -413,7 +430,8 @@ class ClickzettaVector(BaseVector):
id_list = ",".join(f"'{id}'" for id in safe_ids) id_list = ",".join(f"'{id}'" for id in safe_ids)
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})" sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(sql) cursor.execute(sql)
def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None:
@ -428,7 +446,8 @@ class ClickzettaVector(BaseVector):
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
"""Implementation of delete by metadata field (executed in write worker thread).""" """Implementation of delete by metadata field (executed in write worker thread)."""
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
# Using JSON path to filter with parameterized query # Using JSON path to filter with parameterized query
# Note: JSON path requires literal key name, cannot be parameterized # Note: JSON path requires literal key name, cannot be parameterized
# Use json_extract_string function for ClickZetta compatibility # Use json_extract_string function for ClickZetta compatibility
@ -488,7 +507,8 @@ class ClickzettaVector(BaseVector):
""" """
documents = [] documents = []
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(search_sql) cursor.execute(search_sql)
results = cursor.fetchall() results = cursor.fetchall()
@ -572,7 +592,8 @@ class ClickzettaVector(BaseVector):
""" """
documents = [] documents = []
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
try: try:
cursor.execute(search_sql) cursor.execute(search_sql)
results = cursor.fetchall() results = cursor.fetchall()
@ -649,7 +670,8 @@ class ClickzettaVector(BaseVector):
""" """
documents = [] documents = []
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(search_sql) cursor.execute(search_sql)
results = cursor.fetchall() results = cursor.fetchall()
@ -689,7 +711,8 @@ class ClickzettaVector(BaseVector):
def delete(self) -> None: def delete(self) -> None:
"""Delete the entire collection.""" """Delete the entire collection."""
with self._connection.cursor() as cursor: connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
@ -718,13 +741,13 @@ class ClickzettaVectorFactory(AbstractVectorFactory):
"""Initialize a Clickzetta vector instance.""" """Initialize a Clickzetta vector instance."""
# Get configuration from environment variables or dataset config # Get configuration from environment variables or dataset config
config = ClickzettaConfig( config = ClickzettaConfig(
username=dify_config.CLICKZETTA_USERNAME, username=dify_config.CLICKZETTA_USERNAME or "",
password=dify_config.CLICKZETTA_PASSWORD, password=dify_config.CLICKZETTA_PASSWORD or "",
instance=dify_config.CLICKZETTA_INSTANCE, instance=dify_config.CLICKZETTA_INSTANCE or "",
service=dify_config.CLICKZETTA_SERVICE, service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com",
workspace=dify_config.CLICKZETTA_WORKSPACE, workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start",
vcluster=dify_config.CLICKZETTA_VCLUSTER, vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap",
schema_name=dify_config.CLICKZETTA_SCHEMA, schema_name=dify_config.CLICKZETTA_SCHEMA or "dify",
batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100, batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100,
enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True, enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True,
analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese", analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",

Loading…
Cancel
Save