fix: resolve CI linting issues and add missing newlines

- Fix all line length issues (120 character limit)
- Remove all trailing whitespace
- Add missing newlines at end of files
- Add CLICKZETTA_VOLUME_DIFY_PREFIX environment variable to docker-compose.yaml
- Ensure proper code formatting for all ClickZetta files

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22551/head
yunqiqiliang 10 months ago
parent f3b1bdc04f
commit f57fa13f1b

@ -62,4 +62,4 @@ class ClickZettaVolumeStorageConfig(BaseSettings):
CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
description="Directory prefix for User Volume to organize Dify files",
default="dify_km",
)
)

@ -66,4 +66,5 @@ class ClickzettaConfig(BaseModel):
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
description="Distance function for vector similarity: l2_distance or cosine_distance",
default="cosine_distance",
)
)

@ -1 +1 @@
# Clickzetta Vector Database Integration for Dify
# Clickzetta Vector Database Integration for Dify

@ -68,7 +68,7 @@ class ClickzettaVector(BaseVector):
"""
Clickzetta vector storage implementation.
"""
# Class-level write queue and lock for serializing writes
_write_queue: Optional[queue.Queue] = None
_write_thread: Optional[threading.Thread] = None
@ -94,13 +94,13 @@ class ClickzettaVector(BaseVector):
vcluster=self._config.vcluster,
schema=self._config.schema_name
)
# Set session parameters for better string handling
with self._connection.cursor() as cursor:
# Use quote mode for string literal escaping to handle quotes better
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
logger.info("Set string literal escape mode to 'quote' for better quote handling")
@classmethod
def _init_write_queue(cls):
"""Initialize the write queue and worker thread."""
@ -110,7 +110,7 @@ class ClickzettaVector(BaseVector):
cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True)
cls._write_thread.start()
logger.info("Started Clickzetta write worker thread")
@classmethod
def _write_worker(cls):
"""Worker thread that processes write tasks sequentially."""
@ -120,7 +120,7 @@ class ClickzettaVector(BaseVector):
task = cls._write_queue.get(timeout=1)
if task is None: # Shutdown signal
break
# Execute the write task
func, args, kwargs, result_queue = task
try:
@ -135,15 +135,15 @@ class ClickzettaVector(BaseVector):
continue
except Exception as e:
logger.exception("Write worker error")
def _execute_write(self, func, *args, **kwargs):
"""Execute a write operation through the queue."""
if ClickzettaVector._write_queue is None:
raise RuntimeError("Write queue not initialized")
result_queue = queue.Queue()
ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))
# Wait for result
success, result = result_queue.get()
if not success:
@ -171,18 +171,18 @@ class ClickzettaVector(BaseVector):
"""Create the collection and add initial documents."""
# Execute table creation through write queue to avoid concurrent conflicts
self._execute_write(self._create_table_and_indexes, embeddings)
# Add initial texts
if texts:
self.add_texts(texts, embeddings, **kwargs)
def _create_table_and_indexes(self, embeddings: list[list[float]]):
"""Create table and indexes (executed in write worker thread)."""
# Check if table already exists to avoid unnecessary index creation
if self._table_exists():
logger.info(f"Table {self._config.schema_name}.{self._table_name} already exists, skipping creation")
return
# Create table with vector and metadata columns
dimension = len(embeddings[0]) if embeddings else 768
@ -191,7 +191,8 @@ class ClickzettaVector(BaseVector):
id STRING NOT NULL COMMENT 'Unique document identifier',
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT 'High-dimensional embedding vector for semantic similarity search',
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
'High-dimensional embedding vector for semantic similarity search',
PRIMARY KEY (id)
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
"""
@ -211,7 +212,7 @@ class ClickzettaVector(BaseVector):
"""Create HNSW vector index for similarity search."""
# Use a fixed index name based on table and column name
index_name = f"idx_{self._table_name}_vector"
# First check if an index already exists on this column
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
@ -223,7 +224,7 @@ class ClickzettaVector(BaseVector):
return
except Exception as e:
logger.warning(f"Failed to check existing indexes: {e}")
index_sql = f"""
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
@ -239,8 +240,8 @@ class ClickzettaVector(BaseVector):
logger.info(f"Created vector index: {index_name}")
except Exception as e:
error_msg = str(e).lower()
if ("already exists" in error_msg or
"already has index" in error_msg or
if ("already exists" in error_msg or
"already has index" in error_msg or
"with the same type" in error_msg):
logger.info(f"Vector index already exists: {e}")
else:
@ -251,7 +252,7 @@ class ClickzettaVector(BaseVector):
"""Create inverted index for full-text search."""
# Use a fixed index name based on table name to avoid duplicates
index_name = f"idx_{self._table_name}_text"
# Check if an inverted index already exists on this column
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
@ -259,14 +260,14 @@ class ClickzettaVector(BaseVector):
for idx in existing_indexes:
idx_str = str(idx).lower()
# More precise check: look for inverted index specifically on the content column
if ("inverted" in idx_str and
if ("inverted" in idx_str and
Field.CONTENT_KEY.value.lower() in idx_str and
(index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)):
logger.info(f"Inverted index already exists on column {Field.CONTENT_KEY.value}: {idx}")
return
except Exception as e:
logger.warning(f"Failed to check existing indexes: {e}")
index_sql = f"""
CREATE INVERTED INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
@ -281,8 +282,8 @@ class ClickzettaVector(BaseVector):
except Exception as e:
error_msg = str(e).lower()
# Handle ClickZetta specific error messages
if (("already exists" in error_msg or
"already has index" in error_msg or
if (("already exists" in error_msg or
"already has index" in error_msg or
"with the same type" in error_msg or
"cannot create inverted index" in error_msg) and
"already has index" in error_msg):
@ -313,44 +314,44 @@ class ClickzettaVector(BaseVector):
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i + batch_size]
batch_embeddings = embeddings[i:i + batch_size]
# Execute batch insert through write queue
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
batch_index: int, batch_size: int, total_batches: int):
"""Insert a batch of documents using parameterized queries (executed in write worker thread)."""
if not batch_docs or not batch_embeddings:
logger.warning("Empty batch provided, skipping insertion")
return
if len(batch_docs) != len(batch_embeddings):
logger.error(f"Mismatch between docs ({len(batch_docs)}) and embeddings ({len(batch_embeddings)})")
return
# Prepare data for parameterized insertion
data_rows = []
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
for doc, embedding in zip(batch_docs, batch_embeddings):
# Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata if doc.metadata else {}
if not isinstance(metadata, dict):
metadata = {}
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
# Fast path for JSON serialization
try:
metadata_json = json.dumps(metadata, ensure_ascii=True)
except (TypeError, ValueError):
logger.warning("JSON serialization failed, using empty dict")
metadata_json = "{}"
content = doc.page_content or ""
# According to ClickZetta docs, vector should be formatted as array string
# According to ClickZetta docs, vector should be formatted as array string
# for external systems: '[1.0, 2.0, 3.0]'
vector_str = '[' + ','.join(map(str, embedding)) + ']'
data_rows.append([doc_id, content, metadata_json, vector_str])
@ -359,17 +360,22 @@ class ClickzettaVector(BaseVector):
if not data_rows:
logger.warning(f"No valid documents to insert in batch {batch_index // batch_size + 1}/{total_batches}")
return
# Use parameterized INSERT with executemany for better performance and security
# Cast JSON and VECTOR in SQL, pass raw data as parameters
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
insert_sql = f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
insert_sql = (
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
)
with self._connection.cursor() as cursor:
try:
cursor.executemany(insert_sql, data_rows)
logger.info(f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)")
logger.info(
f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)"
)
except Exception as e:
logger.exception(f"Parameterized SQL execution failed for {len(data_rows)} documents: {e}")
logger.exception(f"SQL template: {insert_sql}")
@ -399,14 +405,14 @@ class ClickzettaVector(BaseVector):
# Execute delete through write queue
self._execute_write(self._delete_by_ids_impl, ids)
def _delete_by_ids_impl(self, ids: list[str]) -> None:
"""Implementation of delete by IDs (executed in write worker thread)."""
safe_ids = [self._safe_doc_id(id) for id in ids]
# Create properly escaped string literals for SQL
id_list = ",".join(f"'{id}'" for id in safe_ids)
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
with self._connection.cursor() as cursor:
cursor.execute(sql)
@ -419,7 +425,7 @@ class ClickzettaVector(BaseVector):
# Execute delete through write queue
self._execute_write(self._delete_by_metadata_field_impl, key, value)
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
"""Implementation of delete by metadata field (executed in write worker thread)."""
with self._connection.cursor() as cursor:
@ -435,7 +441,7 @@ class ClickzettaVector(BaseVector):
top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 0.0)
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
@ -445,8 +451,10 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})")
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table
# Add distance threshold based on distance function
@ -489,11 +497,11 @@ class ClickzettaVector(BaseVector):
try:
if row[2]:
metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
@ -504,14 +512,14 @@ class ClickzettaVector(BaseVector):
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
# Add score based on distance
if self._config.vector_distance_function == "cosine_distance":
metadata["score"] = 1 - (row[3] / 2)
@ -531,7 +539,7 @@ class ClickzettaVector(BaseVector):
top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
@ -541,8 +549,10 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})")
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table
# Use match_all function for full-text search
@ -572,11 +582,11 @@ class ClickzettaVector(BaseVector):
try:
if row[2]:
metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
@ -587,14 +597,14 @@ class ClickzettaVector(BaseVector):
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
# Add a relevance score for full-text search
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata)
@ -610,7 +620,7 @@ class ClickzettaVector(BaseVector):
"""Fallback search using LIKE operator."""
top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
@ -620,8 +630,10 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})")
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
@ -646,11 +658,11 @@ class ClickzettaVector(BaseVector):
try:
if row[2]:
metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
@ -661,14 +673,14 @@ class ClickzettaVector(BaseVector):
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
metadata["score"] = 0.5 # Lower score for LIKE search
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
@ -680,11 +692,11 @@ class ClickzettaVector(BaseVector):
with self._connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
def _format_vector_simple(self, vector: list[float]) -> str:
"""Simple vector formatting for SQL queries."""
return ','.join(map(str, vector))
def _safe_doc_id(self, doc_id: str) -> str:
"""Ensure doc_id is safe for SQL and doesn't contain special characters."""
if not doc_id:
@ -696,7 +708,7 @@ class ClickzettaVector(BaseVector):
if not safe_id: # If all characters were removed
return str(uuid.uuid4())
return safe_id[:255] # Limit length
class ClickzettaVectorFactory(AbstractVectorFactory):
@ -724,3 +736,4 @@ class ClickzettaVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
return ClickzettaVector(collection_name=collection_name, config=config)

@ -2,4 +2,4 @@
from .clickzetta_volume_storage import ClickZettaVolumeStorage
__all__ = ["ClickZettaVolumeStorage"]
__all__ = ["ClickZettaVolumeStorage"]

@ -526,4 +526,4 @@ class ClickZettaVolumeStorage(BaseStorage):
except Exception as e:
logger.error(f"Error scanning path {path}: {e}")
return []
return []

@ -508,4 +508,4 @@ class FileLifecycleManager:
except Exception as e:
logger.error(f"Permission check failed for {filename} operation {operation}: {e}")
# 安全默认:权限检查失败时拒绝访问
return False
return False

@ -22,10 +22,10 @@ class VolumePermission(Enum):
class VolumePermissionManager:
"""Volume权限管理器"""
def __init__(self, connection_or_config, volume_type: str = None, volume_name: Optional[str] = None):
"""初始化权限管理器
Args:
connection_or_config: ClickZetta连接对象或配置字典
volume_type: Volume类型 (user|table|external)
@ -52,22 +52,22 @@ class VolumePermissionManager:
self._connection = connection_or_config
self._volume_type = volume_type
self._volume_name = volume_name
if not self._connection:
raise ValueError("Valid connection or config is required")
if not self._volume_type:
raise ValueError("volume_type is required")
self._permission_cache: Dict[str, Set[str]] = {}
self._current_username = None # 将从连接中获取当前用户名
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
"""检查用户是否有执行特定操作的权限
Args:
operation: 要执行的操作类型
dataset_id: 数据集ID (用于table volume)
Returns:
True if user has permission, False otherwise
"""
@ -81,14 +81,14 @@ class VolumePermissionManager:
else:
logger.warning(f"Unknown volume type: {self._volume_type}")
return False
except Exception as e:
logger.error(f"Permission check failed: {e}")
return False
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
"""检查User Volume权限
User Volume权限规则:
- 用户对自己的User Volume有全部权限
- 只要用户能够连接到ClickZetta就默认具有User Volume的基本权限
@ -97,29 +97,34 @@ class VolumePermissionManager:
try:
# 获取当前用户名
current_user = self._get_current_username()
# 检查基本连接状态
with self._connection.cursor() as cursor:
# 简单的连接测试,如果能执行查询说明用户有基本权限
cursor.execute("SELECT 1")
result = cursor.fetchone()
if result:
logger.debug(f"User Volume permission check for {current_user}, operation {operation.name}: granted (basic connection verified)")
logger.debug(
f"User Volume permission check for {current_user}, operation {operation.name}: "
f"granted (basic connection verified)"
)
return True
else:
logger.warning(f"User Volume permission check failed: cannot verify basic connection for {current_user}")
logger.warning(
f"User Volume permission check failed: cannot verify basic connection for {current_user}"
)
return False
except Exception as e:
logger.error(f"User Volume permission check failed: {e}")
# 对于User Volume如果权限检查失败可能是配置问题给出更友好的错误提示
logger.info(f"User Volume permission check failed, but permission checking is disabled in this version")
return False
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
"""检查Table Volume权限
Table Volume权限规则:
- Table Volume权限继承对应表的权限
- SELECT权限 -> 可以READ/LIST文件
@ -128,29 +133,29 @@ class VolumePermissionManager:
if not dataset_id:
logger.warning("dataset_id is required for table volume permission check")
return False
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
try:
# 检查表权限
permissions = self._get_table_permissions(table_name)
required_permissions = set(operation.value.split(","))
# 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions)
logger.debug(f"Table Volume permission check for {table_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}")
return has_permission
except Exception as e:
logger.error(f"Table volume permission check failed for {table_name}: {e}")
return False
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
"""检查External Volume权限
External Volume权限规则:
- 尝试获取对External Volume的权限
- 如果权限检查失败进行备选验证
@ -159,29 +164,29 @@ class VolumePermissionManager:
if not self._volume_name:
logger.warning("volume_name is required for external volume permission check")
return False
try:
# 检查External Volume权限
permissions = self._get_external_volume_permissions(self._volume_name)
# External Volume权限映射根据操作类型确定所需权限
required_permissions = set()
if operation in [VolumePermission.READ, VolumePermission.LIST]:
required_permissions.add("read")
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
required_permissions.add("write")
# 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions)
logger.debug(f"External Volume permission check for {self._volume_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}")
# 如果权限检查失败,尝试备选验证
if not has_permission:
logger.info(f"Direct permission check failed for {self._volume_name}, trying fallback verification")
# 备选验证尝试列出Volume来验证基本访问权限
try:
with self._connection.cursor() as cursor:
@ -193,43 +198,43 @@ class VolumePermissionManager:
return True
except Exception as fallback_e:
logger.warning(f"Fallback verification failed for {self._volume_name}: {fallback_e}")
return has_permission
except Exception as e:
logger.error(f"External volume permission check failed for {self._volume_name}: {e}")
logger.info(f"External Volume permission check failed, but permission checking is disabled in this version")
return False
def _get_table_permissions(self, table_name: str) -> Set[str]:
"""获取用户对指定表的权限
Args:
table_name: 表名
Returns:
用户对该表的权限集合
"""
cache_key = f"table:{table_name}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查当前用户权限
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
# 解析权限结果,查找对该表的权限
for grant in grants:
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else ""
object_name = grant[2] if len(grant) > 2 else ""
# 检查是否是对该表的权限
if object_type == "TABLE" and object_name == table_name:
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
@ -244,7 +249,7 @@ class VolumePermissionManager:
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else:
permissions.add(privilege)
# 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
if not permissions:
try:
@ -252,21 +257,21 @@ class VolumePermissionManager:
permissions.add("SELECT")
except Exception:
logger.debug(f"Cannot query table {table_name}, no SELECT permission")
except Exception as e:
logger.warning(f"Could not check table permissions for {table_name}: {e}")
# 安全默认:权限检查失败时拒绝访问
pass
# 缓存权限信息
self._permission_cache[cache_key] = permissions
return permissions
def _get_current_username(self) -> str:
"""获取当前用户名"""
if self._current_username:
return self._current_username
try:
with self._connection.cursor() as cursor:
cursor.execute("SELECT CURRENT_USER()")
@ -276,73 +281,74 @@ class VolumePermissionManager:
return self._current_username
except Exception as e:
logger.error(f"Failed to get current username: {e}")
return "unknown"
def _get_user_permissions(self, username: str) -> Set[str]:
"""获取用户的基本权限集合"""
cache_key = f"user_permissions:{username}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查当前用户权限
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
# 解析权限结果,查找用户的基本权限
for grant in grants:
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else ""
# 收集所有相关权限
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else:
permissions.add(privilege)
except Exception as e:
logger.warning(f"Could not check user permissions for {username}: {e}")
# 安全默认:权限检查失败时拒绝访问
pass
# 缓存权限信息
self._permission_cache[cache_key] = permissions
return permissions
def _get_external_volume_permissions(self, volume_name: str) -> Set[str]:
"""获取用户对指定External Volume的权限
Args:
volume_name: External Volume名称
Returns:
用户对该Volume的权限集合
"""
cache_key = f"external_volume:{volume_name}"
if cache_key in self._permission_cache:
return self._permission_cache[cache_key]
permissions = set()
try:
with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查Volume权限
logger.info(f"Checking permissions for volume: {volume_name}")
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
grants = cursor.fetchall()
logger.info(f"Raw grants result for {volume_name}: {grants}")
# 解析权限结果
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to, grantee_name, grantor_name, grant_option, granted_time)
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
# grantee_name, grantor_name, grant_option, granted_time)
for grant in grants:
logger.info(f"Processing grant: {grant}")
if len(grant) >= 5:
@ -350,15 +356,19 @@ class VolumePermissionManager:
privilege = grant[1].upper()
granted_on = grant[3]
object_name = grant[4]
logger.info(f"Grant details - type: {granted_type}, privilege: {privilege}, granted_on: {granted_on}, object_name: {object_name}")
logger.info(
f"Grant details - type: {granted_type}, privilege: {privilege}, "
f"granted_on: {granted_on}, object_name: {object_name}"
)
# 检查是否是对该Volume的权限或者是层级权限
if (granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)) or \
(granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
if ((granted_type == "PRIVILEGE" and granted_on == "VOLUME" and
object_name.endswith(volume_name)) or
(granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME")):
logger.info(f"Matching grant found for {volume_name}")
if "READ" in privilege:
permissions.add("read")
logger.info(f"Added READ permission for {volume_name}")
@ -371,9 +381,9 @@ class VolumePermissionManager:
if privilege == "ALL":
permissions.update(["read", "write", "alter"])
logger.info(f"Added ALL permissions for {volume_name}")
logger.info(f"Final permissions for {volume_name}: {permissions}")
# 如果没有找到明确的权限尝试查看Volume列表来验证基本权限
if not permissions:
try:
@ -386,7 +396,7 @@ class VolumePermissionManager:
break
except Exception:
logger.debug(f"Cannot access volume {volume_name}, no basic permission")
except Exception as e:
logger.warning(f"Could not check external volume permissions for {volume_name}: {e}")
# 在权限检查失败时尝试基本的Volume访问验证
@ -404,102 +414,102 @@ class VolumePermissionManager:
logger.warning(f"Basic volume access check failed for {volume_name}: {basic_e}")
# 最后的备选方案:假设有基本权限
permissions.add("read")
# 缓存权限信息
self._permission_cache[cache_key] = permissions
return permissions
def clear_permission_cache(self):
    """Empty the permission cache in place so later checks query the server again."""
    cached_permissions = self._permission_cache
    cached_permissions.clear()
    logger.debug("Permission cache cleared")
def get_permission_summary(self, dataset_id: Optional[str] = None) -> Dict[str, bool]:
    """Build a summary of which operations are currently permitted.

    Args:
        dataset_id: Dataset ID (used for table volumes).

    Returns:
        Mapping from each ``VolumePermission`` member's lower-cased name to the
        result of ``check_permission`` for that member.
    """
    return {
        operation.name.lower(): self.check_permission(operation, dataset_id)
        for operation in VolumePermission
    }
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
    """Check whether ``operation`` is permitted on ``file_path`` for this volume type.

    Args:
        file_path: Path of the file inside the volume.
        operation: Operation to be performed.

    Returns:
        True if the user has permission, False otherwise. Any exception raised
        during the check is logged and treated as a denial.
    """
    try:
        # Reject empty / all-slash paths up front. str.split("/") never returns an
        # empty list ("".split("/") == [""]), so testing the split result alone
        # would never catch this case.
        trimmed_path = file_path.strip("/")
        if not trimmed_path:
            logger.warning("Invalid file path for permission inheritance check")
            return False

        path_parts = trimmed_path.split("/")

        if self._volume_type == "table":
            # For a table volume the first path segment is the dataset id.
            dataset_id = path_parts[0]

            # The dataset-level permission gates everything beneath it.
            has_dataset_permission = self.check_permission(operation, dataset_id)
            if not has_dataset_permission:
                logger.debug(f"Permission denied for dataset {dataset_id}")
                return False

            # Refuse paths that look like directory-traversal attempts.
            if self._contains_path_traversal(file_path):
                logger.warning(f"Path traversal attack detected: {file_path}")
                return False

            # Refuse paths that match the sensitive-path patterns.
            if self._is_sensitive_path(file_path):
                logger.warning(f"Access to sensitive path denied: {file_path}")
                return False

            logger.debug(f"Permission inherited for path {file_path}")
            return True

        elif self._volume_type == "user":
            current_user = self._get_current_username()

            # A multi-segment path must start with the current user's own name;
            # anything else is treated as reaching into another user's directory.
            if len(path_parts) > 1 and path_parts[0] != current_user:
                logger.warning(f"User {current_user} attempted to access {path_parts[0]}'s directory")
                return False

            # Fall back to the basic user-volume permission check.
            return self.check_permission(operation)

        elif self._volume_type == "external":
            # External volumes have no per-path rules; use the volume-level check.
            return self.check_permission(operation)

        else:
            logger.warning(f"Unknown volume type for permission inheritance: {self._volume_type}")
            return False

    except Exception as e:
        logger.error(f"Permission inheritance check failed: {e}")
        return False
def _contains_path_traversal(self, file_path: str) -> bool:
"""检查路径是否包含路径遍历攻击"""
# 检查常见的路径遍历模式
@ -509,23 +519,23 @@ class VolumePermissionManager:
"%2e%2e%2f", "%2e%2e%5c",
"....//", "....\\\\",
]
file_path_lower = file_path.lower()
for pattern in traversal_patterns:
if pattern in file_path_lower:
return True
# 检查绝对路径
if file_path.startswith("/") or file_path.startswith("\\"):
return True
# 检查Windows驱动器路径
if len(file_path) >= 2 and file_path[1] == ":":
return True
return False
def _is_sensitive_path(self, file_path: str) -> bool:
"""检查路径是否为敏感路径"""
sensitive_patterns = [
@ -533,22 +543,22 @@ class VolumePermissionManager:
"private", "key", "certificate", "cert", "ssl",
"database", "backup", "dump", "log", "tmp"
]
file_path_lower = file_path.lower()
for pattern in sensitive_patterns:
if pattern in file_path_lower:
return True
return False
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
"""验证操作权限
Args:
operation: 操作名称 (save|load|exists|delete|scan)
dataset_id: 数据集ID
Returns:
True if operation is allowed, False otherwise
"""
@ -562,18 +572,18 @@ class VolumePermissionManager:
"delete": VolumePermission.DELETE,
"scan": VolumePermission.LIST,
}
if operation not in operation_mapping:
logger.warning(f"Unknown operation: {operation}")
return False
volume_permission = operation_mapping[operation]
return self.check_permission(volume_permission, dataset_id)
class VolumePermissionError(Exception):
"""Volume权限错误异常"""
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
self.operation = operation
self.volume_type = volume_type
@ -581,16 +591,16 @@ class VolumePermissionError(Exception):
super().__init__(message)
def check_volume_permission(permission_manager: VolumePermissionManager,
operation: str,
def check_volume_permission(permission_manager: VolumePermissionManager,
operation: str,
dataset_id: Optional[str] = None) -> None:
"""权限检查装饰器函数
Args:
permission_manager: 权限管理器
operation: 操作名称
dataset_id: 数据集ID
Raises:
VolumePermissionError: 如果没有权限
"""
@ -598,10 +608,10 @@ def check_volume_permission(permission_manager: VolumePermissionManager,
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
if dataset_id:
error_message += f" (dataset: {dataset_id})"
raise VolumePermissionError(
error_message,
operation=operation,
volume_type=permission_manager._volume_type,
dataset_id=dataset_id
)
)

@ -177,4 +177,4 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
if __name__ == "__main__":
unittest.main()
unittest.main()

@ -234,4 +234,4 @@ class TestClickzettaVector(AbstractVectorTest):
# Clean up
vector_store.delete_by_metadata_field("lang", "chinese")
vector_store.delete_by_metadata_field("lang", "english")
vector_store.delete_by_metadata_field("lang", "english")

@ -162,4 +162,4 @@ def main():
return 1
if __name__ == "__main__":
exit(main())
exit(main())

@ -87,11 +87,12 @@ x-shared-env: &shared-api-worker-env
WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*}
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user}
CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-}
CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km}
S3_ENDPOINT: ${S3_ENDPOINT:-}
S3_REGION: ${S3_REGION:-us-east-1}
S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}

Loading…
Cancel
Save