fix: resolve CI linting issues and add missing newlines

- Fix all line length issues (120 character limit)
- Remove all trailing whitespace
- Add missing newlines at end of files
- Add CLICKZETTA_VOLUME_DIFY_PREFIX environment variable to docker-compose.yaml
- Ensure proper code formatting for all ClickZetta files

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22551/head
yunqiqiliang 10 months ago
parent f3b1bdc04f
commit f57fa13f1b

@ -62,4 +62,4 @@ class ClickZettaVolumeStorageConfig(BaseSettings):
CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field( CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
description="Directory prefix for User Volume to organize Dify files", description="Directory prefix for User Volume to organize Dify files",
default="dify_km", default="dify_km",
) )

@ -66,4 +66,5 @@ class ClickzettaConfig(BaseModel):
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
description="Distance function for vector similarity: l2_distance or cosine_distance", description="Distance function for vector similarity: l2_distance or cosine_distance",
default="cosine_distance", default="cosine_distance",
) )

@ -1 +1 @@
# Clickzetta Vector Database Integration for Dify # Clickzetta Vector Database Integration for Dify

@ -68,7 +68,7 @@ class ClickzettaVector(BaseVector):
""" """
Clickzetta vector storage implementation. Clickzetta vector storage implementation.
""" """
# Class-level write queue and lock for serializing writes # Class-level write queue and lock for serializing writes
_write_queue: Optional[queue.Queue] = None _write_queue: Optional[queue.Queue] = None
_write_thread: Optional[threading.Thread] = None _write_thread: Optional[threading.Thread] = None
@ -94,13 +94,13 @@ class ClickzettaVector(BaseVector):
vcluster=self._config.vcluster, vcluster=self._config.vcluster,
schema=self._config.schema_name schema=self._config.schema_name
) )
# Set session parameters for better string handling # Set session parameters for better string handling
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
# Use quote mode for string literal escaping to handle quotes better # Use quote mode for string literal escaping to handle quotes better
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
logger.info("Set string literal escape mode to 'quote' for better quote handling") logger.info("Set string literal escape mode to 'quote' for better quote handling")
@classmethod @classmethod
def _init_write_queue(cls): def _init_write_queue(cls):
"""Initialize the write queue and worker thread.""" """Initialize the write queue and worker thread."""
@ -110,7 +110,7 @@ class ClickzettaVector(BaseVector):
cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True) cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True)
cls._write_thread.start() cls._write_thread.start()
logger.info("Started Clickzetta write worker thread") logger.info("Started Clickzetta write worker thread")
@classmethod @classmethod
def _write_worker(cls): def _write_worker(cls):
"""Worker thread that processes write tasks sequentially.""" """Worker thread that processes write tasks sequentially."""
@ -120,7 +120,7 @@ class ClickzettaVector(BaseVector):
task = cls._write_queue.get(timeout=1) task = cls._write_queue.get(timeout=1)
if task is None: # Shutdown signal if task is None: # Shutdown signal
break break
# Execute the write task # Execute the write task
func, args, kwargs, result_queue = task func, args, kwargs, result_queue = task
try: try:
@ -135,15 +135,15 @@ class ClickzettaVector(BaseVector):
continue continue
except Exception as e: except Exception as e:
logger.exception("Write worker error") logger.exception("Write worker error")
def _execute_write(self, func, *args, **kwargs): def _execute_write(self, func, *args, **kwargs):
"""Execute a write operation through the queue.""" """Execute a write operation through the queue."""
if ClickzettaVector._write_queue is None: if ClickzettaVector._write_queue is None:
raise RuntimeError("Write queue not initialized") raise RuntimeError("Write queue not initialized")
result_queue = queue.Queue() result_queue = queue.Queue()
ClickzettaVector._write_queue.put((func, args, kwargs, result_queue)) ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))
# Wait for result # Wait for result
success, result = result_queue.get() success, result = result_queue.get()
if not success: if not success:
@ -171,18 +171,18 @@ class ClickzettaVector(BaseVector):
"""Create the collection and add initial documents.""" """Create the collection and add initial documents."""
# Execute table creation through write queue to avoid concurrent conflicts # Execute table creation through write queue to avoid concurrent conflicts
self._execute_write(self._create_table_and_indexes, embeddings) self._execute_write(self._create_table_and_indexes, embeddings)
# Add initial texts # Add initial texts
if texts: if texts:
self.add_texts(texts, embeddings, **kwargs) self.add_texts(texts, embeddings, **kwargs)
def _create_table_and_indexes(self, embeddings: list[list[float]]): def _create_table_and_indexes(self, embeddings: list[list[float]]):
"""Create table and indexes (executed in write worker thread).""" """Create table and indexes (executed in write worker thread)."""
# Check if table already exists to avoid unnecessary index creation # Check if table already exists to avoid unnecessary index creation
if self._table_exists(): if self._table_exists():
logger.info(f"Table {self._config.schema_name}.{self._table_name} already exists, skipping creation") logger.info(f"Table {self._config.schema_name}.{self._table_name} already exists, skipping creation")
return return
# Create table with vector and metadata columns # Create table with vector and metadata columns
dimension = len(embeddings[0]) if embeddings else 768 dimension = len(embeddings[0]) if embeddings else 768
@ -191,7 +191,8 @@ class ClickzettaVector(BaseVector):
id STRING NOT NULL COMMENT 'Unique document identifier', id STRING NOT NULL COMMENT 'Unique document identifier',
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval', {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes', {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT 'High-dimensional embedding vector for semantic similarity search', {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
'High-dimensional embedding vector for semantic similarity search',
PRIMARY KEY (id) PRIMARY KEY (id)
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
""" """
@ -211,7 +212,7 @@ class ClickzettaVector(BaseVector):
"""Create HNSW vector index for similarity search.""" """Create HNSW vector index for similarity search."""
# Use a fixed index name based on table and column name # Use a fixed index name based on table and column name
index_name = f"idx_{self._table_name}_vector" index_name = f"idx_{self._table_name}_vector"
# First check if an index already exists on this column # First check if an index already exists on this column
try: try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
@ -223,7 +224,7 @@ class ClickzettaVector(BaseVector):
return return
except Exception as e: except Exception as e:
logger.warning(f"Failed to check existing indexes: {e}") logger.warning(f"Failed to check existing indexes: {e}")
index_sql = f""" index_sql = f"""
CREATE VECTOR INDEX IF NOT EXISTS {index_name} CREATE VECTOR INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value}) ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
@ -239,8 +240,8 @@ class ClickzettaVector(BaseVector):
logger.info(f"Created vector index: {index_name}") logger.info(f"Created vector index: {index_name}")
except Exception as e: except Exception as e:
error_msg = str(e).lower() error_msg = str(e).lower()
if ("already exists" in error_msg or if ("already exists" in error_msg or
"already has index" in error_msg or "already has index" in error_msg or
"with the same type" in error_msg): "with the same type" in error_msg):
logger.info(f"Vector index already exists: {e}") logger.info(f"Vector index already exists: {e}")
else: else:
@ -251,7 +252,7 @@ class ClickzettaVector(BaseVector):
"""Create inverted index for full-text search.""" """Create inverted index for full-text search."""
# Use a fixed index name based on table name to avoid duplicates # Use a fixed index name based on table name to avoid duplicates
index_name = f"idx_{self._table_name}_text" index_name = f"idx_{self._table_name}_text"
# Check if an inverted index already exists on this column # Check if an inverted index already exists on this column
try: try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
@ -259,14 +260,14 @@ class ClickzettaVector(BaseVector):
for idx in existing_indexes: for idx in existing_indexes:
idx_str = str(idx).lower() idx_str = str(idx).lower()
# More precise check: look for inverted index specifically on the content column # More precise check: look for inverted index specifically on the content column
if ("inverted" in idx_str and if ("inverted" in idx_str and
Field.CONTENT_KEY.value.lower() in idx_str and Field.CONTENT_KEY.value.lower() in idx_str and
(index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)): (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)):
logger.info(f"Inverted index already exists on column {Field.CONTENT_KEY.value}: {idx}") logger.info(f"Inverted index already exists on column {Field.CONTENT_KEY.value}: {idx}")
return return
except Exception as e: except Exception as e:
logger.warning(f"Failed to check existing indexes: {e}") logger.warning(f"Failed to check existing indexes: {e}")
index_sql = f""" index_sql = f"""
CREATE INVERTED INDEX IF NOT EXISTS {index_name} CREATE INVERTED INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value}) ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
@ -281,8 +282,8 @@ class ClickzettaVector(BaseVector):
except Exception as e: except Exception as e:
error_msg = str(e).lower() error_msg = str(e).lower()
# Handle ClickZetta specific error messages # Handle ClickZetta specific error messages
if (("already exists" in error_msg or if (("already exists" in error_msg or
"already has index" in error_msg or "already has index" in error_msg or
"with the same type" in error_msg or "with the same type" in error_msg or
"cannot create inverted index" in error_msg) and "cannot create inverted index" in error_msg) and
"already has index" in error_msg): "already has index" in error_msg):
@ -313,44 +314,44 @@ class ClickzettaVector(BaseVector):
for i in range(0, len(documents), batch_size): for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i + batch_size] batch_docs = documents[i:i + batch_size]
batch_embeddings = embeddings[i:i + batch_size] batch_embeddings = embeddings[i:i + batch_size]
# Execute batch insert through write queue # Execute batch insert through write queue
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]], def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
batch_index: int, batch_size: int, total_batches: int): batch_index: int, batch_size: int, total_batches: int):
"""Insert a batch of documents using parameterized queries (executed in write worker thread).""" """Insert a batch of documents using parameterized queries (executed in write worker thread)."""
if not batch_docs or not batch_embeddings: if not batch_docs or not batch_embeddings:
logger.warning("Empty batch provided, skipping insertion") logger.warning("Empty batch provided, skipping insertion")
return return
if len(batch_docs) != len(batch_embeddings): if len(batch_docs) != len(batch_embeddings):
logger.error(f"Mismatch between docs ({len(batch_docs)}) and embeddings ({len(batch_embeddings)})") logger.error(f"Mismatch between docs ({len(batch_docs)}) and embeddings ({len(batch_embeddings)})")
return return
# Prepare data for parameterized insertion # Prepare data for parameterized insertion
data_rows = [] data_rows = []
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
for doc, embedding in zip(batch_docs, batch_embeddings): for doc, embedding in zip(batch_docs, batch_embeddings):
# Optimized: minimal checks for common case, fallback for edge cases # Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata if doc.metadata else {} metadata = doc.metadata if doc.metadata else {}
if not isinstance(metadata, dict): if not isinstance(metadata, dict):
metadata = {} metadata = {}
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
# Fast path for JSON serialization # Fast path for JSON serialization
try: try:
metadata_json = json.dumps(metadata, ensure_ascii=True) metadata_json = json.dumps(metadata, ensure_ascii=True)
except (TypeError, ValueError): except (TypeError, ValueError):
logger.warning("JSON serialization failed, using empty dict") logger.warning("JSON serialization failed, using empty dict")
metadata_json = "{}" metadata_json = "{}"
content = doc.page_content or "" content = doc.page_content or ""
# According to ClickZetta docs, vector should be formatted as array string # According to ClickZetta docs, vector should be formatted as array string
# for external systems: '[1.0, 2.0, 3.0]' # for external systems: '[1.0, 2.0, 3.0]'
vector_str = '[' + ','.join(map(str, embedding)) + ']' vector_str = '[' + ','.join(map(str, embedding)) + ']'
data_rows.append([doc_id, content, metadata_json, vector_str]) data_rows.append([doc_id, content, metadata_json, vector_str])
@ -359,17 +360,22 @@ class ClickzettaVector(BaseVector):
if not data_rows: if not data_rows:
logger.warning(f"No valid documents to insert in batch {batch_index // batch_size + 1}/{total_batches}") logger.warning(f"No valid documents to insert in batch {batch_index // batch_size + 1}/{total_batches}")
return return
# Use parameterized INSERT with executemany for better performance and security # Use parameterized INSERT with executemany for better performance and security
# Cast JSON and VECTOR in SQL, pass raw data as parameters # Cast JSON and VECTOR in SQL, pass raw data as parameters
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}" columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
insert_sql = f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" insert_sql = (
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
)
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
try: try:
cursor.executemany(insert_sql, data_rows) cursor.executemany(insert_sql, data_rows)
logger.info(f"Inserted batch {batch_index // batch_size + 1}/{total_batches} " logger.info(
f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)") f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)"
)
except Exception as e: except Exception as e:
logger.exception(f"Parameterized SQL execution failed for {len(data_rows)} documents: {e}") logger.exception(f"Parameterized SQL execution failed for {len(data_rows)} documents: {e}")
logger.exception(f"SQL template: {insert_sql}") logger.exception(f"SQL template: {insert_sql}")
@ -399,14 +405,14 @@ class ClickzettaVector(BaseVector):
# Execute delete through write queue # Execute delete through write queue
self._execute_write(self._delete_by_ids_impl, ids) self._execute_write(self._delete_by_ids_impl, ids)
def _delete_by_ids_impl(self, ids: list[str]) -> None: def _delete_by_ids_impl(self, ids: list[str]) -> None:
"""Implementation of delete by IDs (executed in write worker thread).""" """Implementation of delete by IDs (executed in write worker thread)."""
safe_ids = [self._safe_doc_id(id) for id in ids] safe_ids = [self._safe_doc_id(id) for id in ids]
# Create properly escaped string literals for SQL # Create properly escaped string literals for SQL
id_list = ",".join(f"'{id}'" for id in safe_ids) id_list = ",".join(f"'{id}'" for id in safe_ids)
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})" sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
cursor.execute(sql) cursor.execute(sql)
@ -419,7 +425,7 @@ class ClickzettaVector(BaseVector):
# Execute delete through write queue # Execute delete through write queue
self._execute_write(self._delete_by_metadata_field_impl, key, value) self._execute_write(self._delete_by_metadata_field_impl, key, value)
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
"""Implementation of delete by metadata field (executed in write worker thread).""" """Implementation of delete by metadata field (executed in write worker thread)."""
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
@ -435,7 +441,7 @@ class ClickzettaVector(BaseVector):
top_k = kwargs.get("top_k", 10) top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 0.0) score_threshold = kwargs.get("score_threshold", 0.0)
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow) # Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {}) filter_param = kwargs.get("filter", {})
@ -445,8 +451,10 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility # Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})") filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table # No need for dataset_id filter since each dataset has its own table
# Add distance threshold based on distance function # Add distance threshold based on distance function
@ -489,11 +497,11 @@ class ClickzettaVector(BaseVector):
try: try:
if row[2]: if row[2]:
metadata = json.loads(row[2]) metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again # If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str): if isinstance(metadata, str):
metadata = json.loads(metadata) metadata = json.loads(metadata)
if not isinstance(metadata, dict): if not isinstance(metadata, dict):
metadata = {} metadata = {}
else: else:
@ -504,14 +512,14 @@ class ClickzettaVector(BaseVector):
import re import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set # Ensure required fields are set
metadata["doc_id"] = row[0] # segment id metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents) # Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata: if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id metadata["document_id"] = row[0] # fallback to segment id
# Add score based on distance # Add score based on distance
if self._config.vector_distance_function == "cosine_distance": if self._config.vector_distance_function == "cosine_distance":
metadata["score"] = 1 - (row[3] / 2) metadata["score"] = 1 - (row[3] / 2)
@ -531,7 +539,7 @@ class ClickzettaVector(BaseVector):
top_k = kwargs.get("top_k", 10) top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow) # Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {}) filter_param = kwargs.get("filter", {})
@ -541,8 +549,10 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility # Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})") filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table # No need for dataset_id filter since each dataset has its own table
# Use match_all function for full-text search # Use match_all function for full-text search
@ -572,11 +582,11 @@ class ClickzettaVector(BaseVector):
try: try:
if row[2]: if row[2]:
metadata = json.loads(row[2]) metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again # If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str): if isinstance(metadata, str):
metadata = json.loads(metadata) metadata = json.loads(metadata)
if not isinstance(metadata, dict): if not isinstance(metadata, dict):
metadata = {} metadata = {}
else: else:
@ -587,14 +597,14 @@ class ClickzettaVector(BaseVector):
import re import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set # Ensure required fields are set
metadata["doc_id"] = row[0] # segment id metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents) # Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata: if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id metadata["document_id"] = row[0] # fallback to segment id
# Add a relevance score for full-text search # Add a relevance score for full-text search
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata) doc = Document(page_content=row[1], metadata=metadata)
@ -610,7 +620,7 @@ class ClickzettaVector(BaseVector):
"""Fallback search using LIKE operator.""" """Fallback search using LIKE operator."""
top_k = kwargs.get("top_k", 10) top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow) # Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {}) filter_param = kwargs.get("filter", {})
@ -620,8 +630,10 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility # Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})") filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table # No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause # Use simple quote escaping for LIKE clause
@ -646,11 +658,11 @@ class ClickzettaVector(BaseVector):
try: try:
if row[2]: if row[2]:
metadata = json.loads(row[2]) metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again # If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str): if isinstance(metadata, str):
metadata = json.loads(metadata) metadata = json.loads(metadata)
if not isinstance(metadata, dict): if not isinstance(metadata, dict):
metadata = {} metadata = {}
else: else:
@ -661,14 +673,14 @@ class ClickzettaVector(BaseVector):
import re import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set # Ensure required fields are set
metadata["doc_id"] = row[0] # segment id metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents) # Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata: if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id metadata["document_id"] = row[0] # fallback to segment id
metadata["score"] = 0.5 # Lower score for LIKE search metadata["score"] = 0.5 # Lower score for LIKE search
doc = Document(page_content=row[1], metadata=metadata) doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc) documents.append(doc)
@ -680,11 +692,11 @@ class ClickzettaVector(BaseVector):
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
def _format_vector_simple(self, vector: list[float]) -> str: def _format_vector_simple(self, vector: list[float]) -> str:
"""Simple vector formatting for SQL queries.""" """Simple vector formatting for SQL queries."""
return ','.join(map(str, vector)) return ','.join(map(str, vector))
def _safe_doc_id(self, doc_id: str) -> str: def _safe_doc_id(self, doc_id: str) -> str:
"""Ensure doc_id is safe for SQL and doesn't contain special characters.""" """Ensure doc_id is safe for SQL and doesn't contain special characters."""
if not doc_id: if not doc_id:
@ -696,7 +708,7 @@ class ClickzettaVector(BaseVector):
if not safe_id: # If all characters were removed if not safe_id: # If all characters were removed
return str(uuid.uuid4()) return str(uuid.uuid4())
return safe_id[:255] # Limit length return safe_id[:255] # Limit length
class ClickzettaVectorFactory(AbstractVectorFactory): class ClickzettaVectorFactory(AbstractVectorFactory):
@ -724,3 +736,4 @@ class ClickzettaVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower() collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
return ClickzettaVector(collection_name=collection_name, config=config) return ClickzettaVector(collection_name=collection_name, config=config)

@ -2,4 +2,4 @@
from .clickzetta_volume_storage import ClickZettaVolumeStorage from .clickzetta_volume_storage import ClickZettaVolumeStorage
__all__ = ["ClickZettaVolumeStorage"] __all__ = ["ClickZettaVolumeStorage"]

@ -526,4 +526,4 @@ class ClickZettaVolumeStorage(BaseStorage):
except Exception as e: except Exception as e:
logger.error(f"Error scanning path {path}: {e}") logger.error(f"Error scanning path {path}: {e}")
return [] return []

@ -508,4 +508,4 @@ class FileLifecycleManager:
except Exception as e: except Exception as e:
logger.error(f"Permission check failed for {filename} operation {operation}: {e}") logger.error(f"Permission check failed for {filename} operation {operation}: {e}")
# 安全默认:权限检查失败时拒绝访问 # 安全默认:权限检查失败时拒绝访问
return False return False

@ -22,10 +22,10 @@ class VolumePermission(Enum):
class VolumePermissionManager: class VolumePermissionManager:
"""Volume权限管理器""" """Volume权限管理器"""
def __init__(self, connection_or_config, volume_type: str = None, volume_name: Optional[str] = None): def __init__(self, connection_or_config, volume_type: str = None, volume_name: Optional[str] = None):
"""初始化权限管理器 """初始化权限管理器
Args: Args:
connection_or_config: ClickZetta连接对象或配置字典 connection_or_config: ClickZetta连接对象或配置字典
volume_type: Volume类型 (user|table|external) volume_type: Volume类型 (user|table|external)
@ -52,22 +52,22 @@ class VolumePermissionManager:
self._connection = connection_or_config self._connection = connection_or_config
self._volume_type = volume_type self._volume_type = volume_type
self._volume_name = volume_name self._volume_name = volume_name
if not self._connection: if not self._connection:
raise ValueError("Valid connection or config is required") raise ValueError("Valid connection or config is required")
if not self._volume_type: if not self._volume_type:
raise ValueError("volume_type is required") raise ValueError("volume_type is required")
self._permission_cache: Dict[str, Set[str]] = {} self._permission_cache: Dict[str, Set[str]] = {}
self._current_username = None # 将从连接中获取当前用户名 self._current_username = None # 将从连接中获取当前用户名
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
"""检查用户是否有执行特定操作的权限 """检查用户是否有执行特定操作的权限
Args: Args:
operation: 要执行的操作类型 operation: 要执行的操作类型
dataset_id: 数据集ID (用于table volume) dataset_id: 数据集ID (用于table volume)
Returns: Returns:
True if user has permission, False otherwise True if user has permission, False otherwise
""" """
@ -81,14 +81,14 @@ class VolumePermissionManager:
else: else:
logger.warning(f"Unknown volume type: {self._volume_type}") logger.warning(f"Unknown volume type: {self._volume_type}")
return False return False
except Exception as e: except Exception as e:
logger.error(f"Permission check failed: {e}") logger.error(f"Permission check failed: {e}")
return False return False
def _check_user_volume_permission(self, operation: VolumePermission) -> bool: def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
"""检查User Volume权限 """检查User Volume权限
User Volume权限规则: User Volume权限规则:
- 用户对自己的User Volume有全部权限 - 用户对自己的User Volume有全部权限
- 只要用户能够连接到ClickZetta就默认具有User Volume的基本权限 - 只要用户能够连接到ClickZetta就默认具有User Volume的基本权限
@ -97,29 +97,34 @@ class VolumePermissionManager:
try: try:
# 获取当前用户名 # 获取当前用户名
current_user = self._get_current_username() current_user = self._get_current_username()
# 检查基本连接状态 # 检查基本连接状态
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
# 简单的连接测试,如果能执行查询说明用户有基本权限 # 简单的连接测试,如果能执行查询说明用户有基本权限
cursor.execute("SELECT 1") cursor.execute("SELECT 1")
result = cursor.fetchone() result = cursor.fetchone()
if result: if result:
logger.debug(f"User Volume permission check for {current_user}, operation {operation.name}: granted (basic connection verified)") logger.debug(
f"User Volume permission check for {current_user}, operation {operation.name}: "
f"granted (basic connection verified)"
)
return True return True
else: else:
logger.warning(f"User Volume permission check failed: cannot verify basic connection for {current_user}") logger.warning(
f"User Volume permission check failed: cannot verify basic connection for {current_user}"
)
return False return False
except Exception as e: except Exception as e:
logger.error(f"User Volume permission check failed: {e}") logger.error(f"User Volume permission check failed: {e}")
# 对于User Volume如果权限检查失败可能是配置问题给出更友好的错误提示 # 对于User Volume如果权限检查失败可能是配置问题给出更友好的错误提示
logger.info(f"User Volume permission check failed, but permission checking is disabled in this version") logger.info(f"User Volume permission check failed, but permission checking is disabled in this version")
return False return False
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
"""检查Table Volume权限 """检查Table Volume权限
Table Volume权限规则: Table Volume权限规则:
- Table Volume权限继承对应表的权限 - Table Volume权限继承对应表的权限
- SELECT权限 -> 可以READ/LIST文件 - SELECT权限 -> 可以READ/LIST文件
@ -128,29 +133,29 @@ class VolumePermissionManager:
if not dataset_id: if not dataset_id:
logger.warning("dataset_id is required for table volume permission check") logger.warning("dataset_id is required for table volume permission check")
return False return False
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
try: try:
# 检查表权限 # 检查表权限
permissions = self._get_table_permissions(table_name) permissions = self._get_table_permissions(table_name)
required_permissions = set(operation.value.split(",")) required_permissions = set(operation.value.split(","))
# 检查是否有所需的所有权限 # 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions) has_permission = required_permissions.issubset(permissions)
logger.debug(f"Table Volume permission check for {table_name}, operation {operation.name}: " logger.debug(f"Table Volume permission check for {table_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}") f"required={required_permissions}, has={permissions}, granted={has_permission}")
return has_permission return has_permission
except Exception as e: except Exception as e:
logger.error(f"Table volume permission check failed for {table_name}: {e}") logger.error(f"Table volume permission check failed for {table_name}: {e}")
return False return False
def _check_external_volume_permission(self, operation: VolumePermission) -> bool: def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
"""检查External Volume权限 """检查External Volume权限
External Volume权限规则: External Volume权限规则:
- 尝试获取对External Volume的权限 - 尝试获取对External Volume的权限
- 如果权限检查失败进行备选验证 - 如果权限检查失败进行备选验证
@ -159,29 +164,29 @@ class VolumePermissionManager:
if not self._volume_name: if not self._volume_name:
logger.warning("volume_name is required for external volume permission check") logger.warning("volume_name is required for external volume permission check")
return False return False
try: try:
# 检查External Volume权限 # 检查External Volume权限
permissions = self._get_external_volume_permissions(self._volume_name) permissions = self._get_external_volume_permissions(self._volume_name)
# External Volume权限映射根据操作类型确定所需权限 # External Volume权限映射根据操作类型确定所需权限
required_permissions = set() required_permissions = set()
if operation in [VolumePermission.READ, VolumePermission.LIST]: if operation in [VolumePermission.READ, VolumePermission.LIST]:
required_permissions.add("read") required_permissions.add("read")
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]: elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
required_permissions.add("write") required_permissions.add("write")
# 检查是否有所需的所有权限 # 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions) has_permission = required_permissions.issubset(permissions)
logger.debug(f"External Volume permission check for {self._volume_name}, operation {operation.name}: " logger.debug(f"External Volume permission check for {self._volume_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}") f"required={required_permissions}, has={permissions}, granted={has_permission}")
# 如果权限检查失败,尝试备选验证 # 如果权限检查失败,尝试备选验证
if not has_permission: if not has_permission:
logger.info(f"Direct permission check failed for {self._volume_name}, trying fallback verification") logger.info(f"Direct permission check failed for {self._volume_name}, trying fallback verification")
# 备选验证尝试列出Volume来验证基本访问权限 # 备选验证尝试列出Volume来验证基本访问权限
try: try:
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
@ -193,43 +198,43 @@ class VolumePermissionManager:
return True return True
except Exception as fallback_e: except Exception as fallback_e:
logger.warning(f"Fallback verification failed for {self._volume_name}: {fallback_e}") logger.warning(f"Fallback verification failed for {self._volume_name}: {fallback_e}")
return has_permission return has_permission
except Exception as e: except Exception as e:
logger.error(f"External volume permission check failed for {self._volume_name}: {e}") logger.error(f"External volume permission check failed for {self._volume_name}: {e}")
logger.info(f"External Volume permission check failed, but permission checking is disabled in this version") logger.info(f"External Volume permission check failed, but permission checking is disabled in this version")
return False return False
def _get_table_permissions(self, table_name: str) -> Set[str]: def _get_table_permissions(self, table_name: str) -> Set[str]:
"""获取用户对指定表的权限 """获取用户对指定表的权限
Args: Args:
table_name: 表名 table_name: 表名
Returns: Returns:
用户对该表的权限集合 用户对该表的权限集合
""" """
cache_key = f"table:{table_name}" cache_key = f"table:{table_name}"
if cache_key in self._permission_cache: if cache_key in self._permission_cache:
return self._permission_cache[cache_key] return self._permission_cache[cache_key]
permissions = set() permissions = set()
try: try:
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查当前用户权限 # 使用正确的ClickZetta语法检查当前用户权限
cursor.execute("SHOW GRANTS") cursor.execute("SHOW GRANTS")
grants = cursor.fetchall() grants = cursor.fetchall()
# 解析权限结果,查找对该表的权限 # 解析权限结果,查找对该表的权限
for grant in grants: for grant in grants:
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
privilege = grant[0].upper() privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else "" object_type = grant[1].upper() if len(grant) > 1 else ""
object_name = grant[2] if len(grant) > 2 else "" object_name = grant[2] if len(grant) > 2 else ""
# 检查是否是对该表的权限 # 检查是否是对该表的权限
if object_type == "TABLE" and object_name == table_name: if object_type == "TABLE" and object_name == table_name:
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
@ -244,7 +249,7 @@ class VolumePermissionManager:
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else: else:
permissions.add(privilege) permissions.add(privilege)
# 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限 # 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
if not permissions: if not permissions:
try: try:
@ -252,21 +257,21 @@ class VolumePermissionManager:
permissions.add("SELECT") permissions.add("SELECT")
except Exception: except Exception:
logger.debug(f"Cannot query table {table_name}, no SELECT permission") logger.debug(f"Cannot query table {table_name}, no SELECT permission")
except Exception as e: except Exception as e:
logger.warning(f"Could not check table permissions for {table_name}: {e}") logger.warning(f"Could not check table permissions for {table_name}: {e}")
# 安全默认:权限检查失败时拒绝访问 # 安全默认:权限检查失败时拒绝访问
pass pass
# 缓存权限信息 # 缓存权限信息
self._permission_cache[cache_key] = permissions self._permission_cache[cache_key] = permissions
return permissions return permissions
def _get_current_username(self) -> str: def _get_current_username(self) -> str:
"""获取当前用户名""" """获取当前用户名"""
if self._current_username: if self._current_username:
return self._current_username return self._current_username
try: try:
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
cursor.execute("SELECT CURRENT_USER()") cursor.execute("SELECT CURRENT_USER()")
@ -276,73 +281,74 @@ class VolumePermissionManager:
return self._current_username return self._current_username
except Exception as e: except Exception as e:
logger.error(f"Failed to get current username: {e}") logger.error(f"Failed to get current username: {e}")
return "unknown" return "unknown"
def _get_user_permissions(self, username: str) -> Set[str]: def _get_user_permissions(self, username: str) -> Set[str]:
"""获取用户的基本权限集合""" """获取用户的基本权限集合"""
cache_key = f"user_permissions:{username}" cache_key = f"user_permissions:{username}"
if cache_key in self._permission_cache: if cache_key in self._permission_cache:
return self._permission_cache[cache_key] return self._permission_cache[cache_key]
permissions = set() permissions = set()
try: try:
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查当前用户权限 # 使用正确的ClickZetta语法检查当前用户权限
cursor.execute("SHOW GRANTS") cursor.execute("SHOW GRANTS")
grants = cursor.fetchall() grants = cursor.fetchall()
# 解析权限结果,查找用户的基本权限 # 解析权限结果,查找用户的基本权限
for grant in grants: for grant in grants:
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
privilege = grant[0].upper() privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else "" object_type = grant[1].upper() if len(grant) > 1 else ""
# 收集所有相关权限 # 收集所有相关权限
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL": if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else: else:
permissions.add(privilege) permissions.add(privilege)
except Exception as e: except Exception as e:
logger.warning(f"Could not check user permissions for {username}: {e}") logger.warning(f"Could not check user permissions for {username}: {e}")
# 安全默认:权限检查失败时拒绝访问 # 安全默认:权限检查失败时拒绝访问
pass pass
# 缓存权限信息 # 缓存权限信息
self._permission_cache[cache_key] = permissions self._permission_cache[cache_key] = permissions
return permissions return permissions
def _get_external_volume_permissions(self, volume_name: str) -> Set[str]: def _get_external_volume_permissions(self, volume_name: str) -> Set[str]:
"""获取用户对指定External Volume的权限 """获取用户对指定External Volume的权限
Args: Args:
volume_name: External Volume名称 volume_name: External Volume名称
Returns: Returns:
用户对该Volume的权限集合 用户对该Volume的权限集合
""" """
cache_key = f"external_volume:{volume_name}" cache_key = f"external_volume:{volume_name}"
if cache_key in self._permission_cache: if cache_key in self._permission_cache:
return self._permission_cache[cache_key] return self._permission_cache[cache_key]
permissions = set() permissions = set()
try: try:
with self._connection.cursor() as cursor: with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查Volume权限 # 使用正确的ClickZetta语法检查Volume权限
logger.info(f"Checking permissions for volume: {volume_name}") logger.info(f"Checking permissions for volume: {volume_name}")
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}") cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
grants = cursor.fetchall() grants = cursor.fetchall()
logger.info(f"Raw grants result for {volume_name}: {grants}") logger.info(f"Raw grants result for {volume_name}: {grants}")
# 解析权限结果 # 解析权限结果
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to, grantee_name, grantor_name, grant_option, granted_time) # 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
# grantee_name, grantor_name, grant_option, granted_time)
for grant in grants: for grant in grants:
logger.info(f"Processing grant: {grant}") logger.info(f"Processing grant: {grant}")
if len(grant) >= 5: if len(grant) >= 5:
@ -350,15 +356,19 @@ class VolumePermissionManager:
privilege = grant[1].upper() privilege = grant[1].upper()
granted_on = grant[3] granted_on = grant[3]
object_name = grant[4] object_name = grant[4]
logger.info(f"Grant details - type: {granted_type}, privilege: {privilege}, granted_on: {granted_on}, object_name: {object_name}") logger.info(
f"Grant details - type: {granted_type}, privilege: {privilege}, "
f"granted_on: {granted_on}, object_name: {object_name}"
)
# 检查是否是对该Volume的权限或者是层级权限 # 检查是否是对该Volume的权限或者是层级权限
if (granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)) or \ if ((granted_type == "PRIVILEGE" and granted_on == "VOLUME" and
(granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"): object_name.endswith(volume_name)) or
(granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME")):
logger.info(f"Matching grant found for {volume_name}") logger.info(f"Matching grant found for {volume_name}")
if "READ" in privilege: if "READ" in privilege:
permissions.add("read") permissions.add("read")
logger.info(f"Added READ permission for {volume_name}") logger.info(f"Added READ permission for {volume_name}")
@ -371,9 +381,9 @@ class VolumePermissionManager:
if privilege == "ALL": if privilege == "ALL":
permissions.update(["read", "write", "alter"]) permissions.update(["read", "write", "alter"])
logger.info(f"Added ALL permissions for {volume_name}") logger.info(f"Added ALL permissions for {volume_name}")
logger.info(f"Final permissions for {volume_name}: {permissions}") logger.info(f"Final permissions for {volume_name}: {permissions}")
# 如果没有找到明确的权限尝试查看Volume列表来验证基本权限 # 如果没有找到明确的权限尝试查看Volume列表来验证基本权限
if not permissions: if not permissions:
try: try:
@ -386,7 +396,7 @@ class VolumePermissionManager:
break break
except Exception: except Exception:
logger.debug(f"Cannot access volume {volume_name}, no basic permission") logger.debug(f"Cannot access volume {volume_name}, no basic permission")
except Exception as e: except Exception as e:
logger.warning(f"Could not check external volume permissions for {volume_name}: {e}") logger.warning(f"Could not check external volume permissions for {volume_name}: {e}")
# 在权限检查失败时尝试基本的Volume访问验证 # 在权限检查失败时尝试基本的Volume访问验证
@ -404,102 +414,102 @@ class VolumePermissionManager:
logger.warning(f"Basic volume access check failed for {volume_name}: {basic_e}") logger.warning(f"Basic volume access check failed for {volume_name}: {basic_e}")
# 最后的备选方案:假设有基本权限 # 最后的备选方案:假设有基本权限
permissions.add("read") permissions.add("read")
# 缓存权限信息 # 缓存权限信息
self._permission_cache[cache_key] = permissions self._permission_cache[cache_key] = permissions
return permissions return permissions
def clear_permission_cache(self): def clear_permission_cache(self):
"""清空权限缓存""" """清空权限缓存"""
self._permission_cache.clear() self._permission_cache.clear()
logger.debug("Permission cache cleared") logger.debug("Permission cache cleared")
def get_permission_summary(self, dataset_id: Optional[str] = None) -> Dict[str, bool]: def get_permission_summary(self, dataset_id: Optional[str] = None) -> Dict[str, bool]:
"""获取权限摘要 """获取权限摘要
Args: Args:
dataset_id: 数据集ID (用于table volume) dataset_id: 数据集ID (用于table volume)
Returns: Returns:
权限摘要字典 权限摘要字典
""" """
summary = {} summary = {}
for operation in VolumePermission: for operation in VolumePermission:
summary[operation.name.lower()] = self.check_permission(operation, dataset_id) summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
return summary return summary
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool: def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
"""检查文件路径的权限继承 """检查文件路径的权限继承
Args: Args:
file_path: 文件路径 file_path: 文件路径
operation: 要执行的操作 operation: 要执行的操作
Returns: Returns:
True if user has permission, False otherwise True if user has permission, False otherwise
""" """
try: try:
# 解析文件路径 # 解析文件路径
path_parts = file_path.strip("/").split("/") path_parts = file_path.strip("/").split("/")
if not path_parts: if not path_parts:
logger.warning("Invalid file path for permission inheritance check") logger.warning("Invalid file path for permission inheritance check")
return False return False
# 对于Table Volume第一层是dataset_id # 对于Table Volume第一层是dataset_id
if self._volume_type == "table": if self._volume_type == "table":
if len(path_parts) < 1: if len(path_parts) < 1:
return False return False
dataset_id = path_parts[0] dataset_id = path_parts[0]
# 检查对dataset的权限 # 检查对dataset的权限
has_dataset_permission = self.check_permission(operation, dataset_id) has_dataset_permission = self.check_permission(operation, dataset_id)
if not has_dataset_permission: if not has_dataset_permission:
logger.debug(f"Permission denied for dataset {dataset_id}") logger.debug(f"Permission denied for dataset {dataset_id}")
return False return False
# 检查路径遍历攻击 # 检查路径遍历攻击
if self._contains_path_traversal(file_path): if self._contains_path_traversal(file_path):
logger.warning(f"Path traversal attack detected: {file_path}") logger.warning(f"Path traversal attack detected: {file_path}")
return False return False
# 检查是否访问敏感目录 # 检查是否访问敏感目录
if self._is_sensitive_path(file_path): if self._is_sensitive_path(file_path):
logger.warning(f"Access to sensitive path denied: {file_path}") logger.warning(f"Access to sensitive path denied: {file_path}")
return False return False
logger.debug(f"Permission inherited for path {file_path}") logger.debug(f"Permission inherited for path {file_path}")
return True return True
elif self._volume_type == "user": elif self._volume_type == "user":
# User Volume的权限继承 # User Volume的权限继承
current_user = self._get_current_username() current_user = self._get_current_username()
# 检查是否试图访问其他用户的目录 # 检查是否试图访问其他用户的目录
if len(path_parts) > 1 and path_parts[0] != current_user: if len(path_parts) > 1 and path_parts[0] != current_user:
logger.warning(f"User {current_user} attempted to access {path_parts[0]}'s directory") logger.warning(f"User {current_user} attempted to access {path_parts[0]}'s directory")
return False return False
# 检查基本权限 # 检查基本权限
return self.check_permission(operation) return self.check_permission(operation)
elif self._volume_type == "external": elif self._volume_type == "external":
# External Volume的权限继承 # External Volume的权限继承
# 检查对External Volume的权限 # 检查对External Volume的权限
return self.check_permission(operation) return self.check_permission(operation)
else: else:
logger.warning(f"Unknown volume type for permission inheritance: {self._volume_type}") logger.warning(f"Unknown volume type for permission inheritance: {self._volume_type}")
return False return False
except Exception as e: except Exception as e:
logger.error(f"Permission inheritance check failed: {e}") logger.error(f"Permission inheritance check failed: {e}")
return False return False
def _contains_path_traversal(self, file_path: str) -> bool: def _contains_path_traversal(self, file_path: str) -> bool:
"""检查路径是否包含路径遍历攻击""" """检查路径是否包含路径遍历攻击"""
# 检查常见的路径遍历模式 # 检查常见的路径遍历模式
@ -509,23 +519,23 @@ class VolumePermissionManager:
"%2e%2e%2f", "%2e%2e%5c", "%2e%2e%2f", "%2e%2e%5c",
"....//", "....\\\\", "....//", "....\\\\",
] ]
file_path_lower = file_path.lower() file_path_lower = file_path.lower()
for pattern in traversal_patterns: for pattern in traversal_patterns:
if pattern in file_path_lower: if pattern in file_path_lower:
return True return True
# 检查绝对路径 # 检查绝对路径
if file_path.startswith("/") or file_path.startswith("\\"): if file_path.startswith("/") or file_path.startswith("\\"):
return True return True
# 检查Windows驱动器路径 # 检查Windows驱动器路径
if len(file_path) >= 2 and file_path[1] == ":": if len(file_path) >= 2 and file_path[1] == ":":
return True return True
return False return False
def _is_sensitive_path(self, file_path: str) -> bool: def _is_sensitive_path(self, file_path: str) -> bool:
"""检查路径是否为敏感路径""" """检查路径是否为敏感路径"""
sensitive_patterns = [ sensitive_patterns = [
@ -533,22 +543,22 @@ class VolumePermissionManager:
"private", "key", "certificate", "cert", "ssl", "private", "key", "certificate", "cert", "ssl",
"database", "backup", "dump", "log", "tmp" "database", "backup", "dump", "log", "tmp"
] ]
file_path_lower = file_path.lower() file_path_lower = file_path.lower()
for pattern in sensitive_patterns: for pattern in sensitive_patterns:
if pattern in file_path_lower: if pattern in file_path_lower:
return True return True
return False return False
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool: def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
"""验证操作权限 """验证操作权限
Args: Args:
operation: 操作名称 (save|load|exists|delete|scan) operation: 操作名称 (save|load|exists|delete|scan)
dataset_id: 数据集ID dataset_id: 数据集ID
Returns: Returns:
True if operation is allowed, False otherwise True if operation is allowed, False otherwise
""" """
@ -562,18 +572,18 @@ class VolumePermissionManager:
"delete": VolumePermission.DELETE, "delete": VolumePermission.DELETE,
"scan": VolumePermission.LIST, "scan": VolumePermission.LIST,
} }
if operation not in operation_mapping: if operation not in operation_mapping:
logger.warning(f"Unknown operation: {operation}") logger.warning(f"Unknown operation: {operation}")
return False return False
volume_permission = operation_mapping[operation] volume_permission = operation_mapping[operation]
return self.check_permission(volume_permission, dataset_id) return self.check_permission(volume_permission, dataset_id)
class VolumePermissionError(Exception): class VolumePermissionError(Exception):
"""Volume权限错误异常""" """Volume权限错误异常"""
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None): def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
self.operation = operation self.operation = operation
self.volume_type = volume_type self.volume_type = volume_type
@ -581,16 +591,16 @@ class VolumePermissionError(Exception):
super().__init__(message) super().__init__(message)
def check_volume_permission(permission_manager: VolumePermissionManager, def check_volume_permission(permission_manager: VolumePermissionManager,
operation: str, operation: str,
dataset_id: Optional[str] = None) -> None: dataset_id: Optional[str] = None) -> None:
"""权限检查装饰器函数 """权限检查装饰器函数
Args: Args:
permission_manager: 权限管理器 permission_manager: 权限管理器
operation: 操作名称 operation: 操作名称
dataset_id: 数据集ID dataset_id: 数据集ID
Raises: Raises:
VolumePermissionError: 如果没有权限 VolumePermissionError: 如果没有权限
""" """
@ -598,10 +608,10 @@ def check_volume_permission(permission_manager: VolumePermissionManager,
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume" error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
if dataset_id: if dataset_id:
error_message += f" (dataset: {dataset_id})" error_message += f" (dataset: {dataset_id})"
raise VolumePermissionError( raise VolumePermissionError(
error_message, error_message,
operation=operation, operation=operation,
volume_type=permission_manager._volume_type, volume_type=permission_manager._volume_type,
dataset_id=dataset_id dataset_id=dataset_id
) )

@ -177,4 +177,4 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

@ -234,4 +234,4 @@ class TestClickzettaVector(AbstractVectorTest):
# Clean up # Clean up
vector_store.delete_by_metadata_field("lang", "chinese") vector_store.delete_by_metadata_field("lang", "chinese")
vector_store.delete_by_metadata_field("lang", "english") vector_store.delete_by_metadata_field("lang", "english")

@ -162,4 +162,4 @@ def main():
return 1 return 1
if __name__ == "__main__": if __name__ == "__main__":
exit(main()) exit(main())

@ -87,11 +87,12 @@ x-shared-env: &shared-api-worker-env
WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*}
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
STORAGE_TYPE: ${STORAGE_TYPE:-opendal} STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user} CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user}
CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-} CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-}
CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_} CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
S3_ENDPOINT: ${S3_ENDPOINT:-} S3_ENDPOINT: ${S3_ENDPOINT:-}
S3_REGION: ${S3_REGION:-us-east-1} S3_REGION: ${S3_REGION:-us-east-1}
S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}

Loading…
Cancel
Save