Fix mypy type-checking errors in the ClickZetta vector implementation

- Add a TYPE_CHECKING-only import of Connection from the clickzetta module and annotate self._connection as Optional["Connection"]
- Implement _ensure_connection() method to handle None connection checks
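- Guard database cursor access against a None connection: most call sites now go through _ensure_connection(), which raises RuntimeError when the connection is missing, and the session-parameter setup uses an explicit `is not None` check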
- Add type annotation for result queue in _execute_write method
- Resolve factory method configuration type errors by falling back to default values (via `or`) when the corresponding settings are None or empty

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22551/head
yunqiqiliang 10 months ago
parent f57fa13f1b
commit c3851595d0

@ -3,11 +3,14 @@ import logging
import queue
import threading
import uuid
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from clickzetta import Connection
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@ -79,7 +82,7 @@ class ClickzettaVector(BaseVector):
super().__init__(collection_name)
self._config = config
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name
self._connection = None
self._connection: Optional["Connection"] = None
self._init_connection()
self._init_write_queue()
@ -96,10 +99,11 @@ class ClickzettaVector(BaseVector):
)
# Set session parameters for better string handling
with self._connection.cursor() as cursor:
# Use quote mode for string literal escaping to handle quotes better
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
logger.info("Set string literal escape mode to 'quote' for better quote handling")
if self._connection is not None:
with self._connection.cursor() as cursor:
# Use quote mode for string literal escaping to handle quotes better
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
logger.info("Set string literal escape mode to 'quote' for better quote handling")
@classmethod
def _init_write_queue(cls):
@ -117,20 +121,23 @@ class ClickzettaVector(BaseVector):
while not cls._shutdown:
try:
# Get task from queue with timeout
task = cls._write_queue.get(timeout=1)
if task is None: # Shutdown signal
break
if cls._write_queue is not None:
task = cls._write_queue.get(timeout=1)
if task is None: # Shutdown signal
break
# Execute the write task
func, args, kwargs, result_queue = task
try:
result = func(*args, **kwargs)
result_queue.put((True, result))
except Exception as e:
logger.exception("Write task failed")
result_queue.put((False, e))
finally:
cls._write_queue.task_done()
# Execute the write task
func, args, kwargs, result_queue = task
try:
result = func(*args, **kwargs)
result_queue.put((True, result))
except Exception as e:
logger.exception("Write task failed")
result_queue.put((False, e))
finally:
cls._write_queue.task_done()
else:
break
except queue.Empty:
continue
except Exception as e:
@ -141,7 +148,7 @@ class ClickzettaVector(BaseVector):
if ClickzettaVector._write_queue is None:
raise RuntimeError("Write queue not initialized")
result_queue = queue.Queue()
result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue()
ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))
# Wait for result
@ -154,10 +161,17 @@ class ClickzettaVector(BaseVector):
"""Return the vector database type."""
return "clickzetta"
def _ensure_connection(self) -> "Connection":
"""Ensure connection is available and return it."""
if self._connection is None:
raise RuntimeError("Database connection not initialized")
return self._connection
def _table_exists(self) -> bool:
"""Check if the table exists."""
try:
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
return True
except Exception as e:
@ -197,7 +211,8 @@ class ClickzettaVector(BaseVector):
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
"""
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(create_table_sql)
logger.info(f"Created table {self._config.schema_name}.{self._table_name}")
@ -369,7 +384,8 @@ class ClickzettaVector(BaseVector):
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
)
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
try:
cursor.executemany(insert_sql, data_rows)
logger.info(
@ -385,7 +401,8 @@ class ClickzettaVector(BaseVector):
def text_exists(self, id: str) -> bool:
"""Check if a document exists by ID."""
safe_id = self._safe_doc_id(id)
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
[safe_id]
@ -413,7 +430,8 @@ class ClickzettaVector(BaseVector):
id_list = ",".join(f"'{id}'" for id in safe_ids)
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(sql)
def delete_by_metadata_field(self, key: str, value: str) -> None:
@ -428,7 +446,8 @@ class ClickzettaVector(BaseVector):
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
"""Implementation of delete by metadata field (executed in write worker thread)."""
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
# Using JSON path to filter with parameterized query
# Note: JSON path requires literal key name, cannot be parameterized
# Use json_extract_string function for ClickZetta compatibility
@ -488,7 +507,8 @@ class ClickzettaVector(BaseVector):
"""
documents = []
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(search_sql)
results = cursor.fetchall()
@ -572,7 +592,8 @@ class ClickzettaVector(BaseVector):
"""
documents = []
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
try:
cursor.execute(search_sql)
results = cursor.fetchall()
@ -649,7 +670,8 @@ class ClickzettaVector(BaseVector):
"""
documents = []
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(search_sql)
results = cursor.fetchall()
@ -689,7 +711,8 @@ class ClickzettaVector(BaseVector):
def delete(self) -> None:
"""Delete the entire collection."""
with self._connection.cursor() as cursor:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
@ -718,13 +741,13 @@ class ClickzettaVectorFactory(AbstractVectorFactory):
"""Initialize a Clickzetta vector instance."""
# Get configuration from environment variables or dataset config
config = ClickzettaConfig(
username=dify_config.CLICKZETTA_USERNAME,
password=dify_config.CLICKZETTA_PASSWORD,
instance=dify_config.CLICKZETTA_INSTANCE,
service=dify_config.CLICKZETTA_SERVICE,
workspace=dify_config.CLICKZETTA_WORKSPACE,
vcluster=dify_config.CLICKZETTA_VCLUSTER,
schema_name=dify_config.CLICKZETTA_SCHEMA,
username=dify_config.CLICKZETTA_USERNAME or "",
password=dify_config.CLICKZETTA_PASSWORD or "",
instance=dify_config.CLICKZETTA_INSTANCE or "",
service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com",
workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start",
vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap",
schema_name=dify_config.CLICKZETTA_SCHEMA or "dify",
batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100,
enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True,
analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",

Loading…
Cancel
Save