diff --git a/api/configs/middleware/storage/clickzetta_volume_storage_config.py b/api/configs/middleware/storage/clickzetta_volume_storage_config.py index f077373622..96eb6d3dd7 100644 --- a/api/configs/middleware/storage/clickzetta_volume_storage_config.py +++ b/api/configs/middleware/storage/clickzetta_volume_storage_config.py @@ -62,4 +62,4 @@ class ClickZettaVolumeStorageConfig(BaseSettings): CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field( description="Directory prefix for User Volume to organize Dify files", default="dify_km", - ) \ No newline at end of file + ) diff --git a/api/configs/middleware/vdb/clickzetta_config.py b/api/configs/middleware/vdb/clickzetta_config.py index a2822dbfee..b08df7a5b5 100644 --- a/api/configs/middleware/vdb/clickzetta_config.py +++ b/api/configs/middleware/vdb/clickzetta_config.py @@ -66,4 +66,5 @@ class ClickzettaConfig(BaseModel): CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( description="Distance function for vector similarity: l2_distance or cosine_distance", default="cosine_distance", - ) \ No newline at end of file + ) + diff --git a/api/core/rag/datasource/vdb/clickzetta/__init__.py b/api/core/rag/datasource/vdb/clickzetta/__init__.py index fecadb863a..9d41c5a57d 100644 --- a/api/core/rag/datasource/vdb/clickzetta/__init__.py +++ b/api/core/rag/datasource/vdb/clickzetta/__init__.py @@ -1 +1 @@ -# Clickzetta Vector Database Integration for Dify \ No newline at end of file +# Clickzetta Vector Database Integration for Dify diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 9e850b2646..181fe56f98 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -68,7 +68,7 @@ class ClickzettaVector(BaseVector): """ Clickzetta vector storage implementation. """ - + # Class-level write queue and lock for serializing writes _write_queue: Optional[queue.Queue] = None _write_thread: Optional[threading.Thread] = None @@ -94,13 +94,13 @@ class ClickzettaVector(BaseVector): vcluster=self._config.vcluster, schema=self._config.schema_name ) - + # Set session parameters for better string handling with self._connection.cursor() as cursor: # Use quote mode for string literal escaping to handle quotes better cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") logger.info("Set string literal escape mode to 'quote' for better quote handling") - + @classmethod def _init_write_queue(cls): """Initialize the write queue and worker thread.""" @@ -110,7 +110,7 @@ class ClickzettaVector(BaseVector): cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True) cls._write_thread.start() logger.info("Started Clickzetta write worker thread") - + @classmethod def _write_worker(cls): """Worker thread that processes write tasks sequentially.""" @@ -120,7 +120,7 @@ class ClickzettaVector(BaseVector): task = cls._write_queue.get(timeout=1) if task is None: # Shutdown signal break - + # Execute the write task func, args, kwargs, result_queue = task try: @@ -135,15 +135,15 @@ class ClickzettaVector(BaseVector): continue except Exception as e: logger.exception("Write worker error") - + def _execute_write(self, func, *args, **kwargs): """Execute a write operation through the queue.""" if ClickzettaVector._write_queue is None: raise RuntimeError("Write queue not initialized") - + result_queue = queue.Queue() ClickzettaVector._write_queue.put((func, args, kwargs, result_queue)) - + # Wait for result success, result = result_queue.get() if not success: @@ -171,18 +171,18 @@ class ClickzettaVector(BaseVector): """Create the collection and add initial documents.""" # Execute table creation through write queue to avoid concurrent conflicts self._execute_write(self._create_table_and_indexes, embeddings) - + # Add initial texts if texts: self.add_texts(texts, embeddings, **kwargs) - + def _create_table_and_indexes(self, embeddings: list[list[float]]): """Create table and indexes (executed in write worker thread).""" # Check if table already exists to avoid unnecessary index creation if self._table_exists(): logger.info(f"Table {self._config.schema_name}.{self._table_name} already exists, skipping creation") return - + # Create table with vector and metadata columns dimension = len(embeddings[0]) if embeddings else 768 @@ -191,7 +191,8 @@ class ClickzettaVector(BaseVector): id STRING NOT NULL COMMENT 'Unique document identifier', {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval', {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes', - {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT 'High-dimensional embedding vector for semantic similarity search', + {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT + 'High-dimensional embedding vector for semantic similarity search', PRIMARY KEY (id) ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' """ @@ -211,7 +212,7 @@ class ClickzettaVector(BaseVector): """Create HNSW vector index for similarity search.""" # Use a fixed index name based on table and column name index_name = f"idx_{self._table_name}_vector" - + # First check if an index already exists on this column try: cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") @@ -223,7 +224,7 @@ class ClickzettaVector(BaseVector): return except Exception as e: logger.warning(f"Failed to check existing indexes: {e}") - + index_sql = f""" CREATE VECTOR INDEX IF NOT EXISTS {index_name} ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value}) @@ -239,8 +240,8 @@ class ClickzettaVector(BaseVector): logger.info(f"Created vector index: {index_name}") except Exception as e: error_msg = str(e).lower() - if ("already exists" in error_msg or - "already has index" in error_msg or + if ("already exists" in error_msg or + "already has index" in error_msg or "with the same type" in error_msg): logger.info(f"Vector index already exists: {e}") else: @@ -251,7 +252,7 @@ class ClickzettaVector(BaseVector): """Create inverted index for full-text search.""" # Use a fixed index name based on table name to avoid duplicates index_name = f"idx_{self._table_name}_text" - + # Check if an inverted index already exists on this column try: cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") @@ -259,14 +260,14 @@ class ClickzettaVector(BaseVector): for idx in existing_indexes: idx_str = str(idx).lower() # More precise check: look for inverted index specifically on the content column - if ("inverted" in idx_str and + if ("inverted" in idx_str and Field.CONTENT_KEY.value.lower() in idx_str and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)): logger.info(f"Inverted index already exists on column {Field.CONTENT_KEY.value}: {idx}") return except Exception as e: logger.warning(f"Failed to check existing indexes: {e}") - + index_sql = f""" CREATE INVERTED INDEX IF NOT EXISTS {index_name} ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value}) @@ -281,8 +282,8 @@ class ClickzettaVector(BaseVector): except Exception as e: error_msg = str(e).lower() # Handle ClickZetta specific error messages - if (("already exists" in error_msg or - "already has index" in error_msg or + if (("already exists" in error_msg or + "already has index" in error_msg or "with the same type" in error_msg or "cannot create inverted index" in error_msg) and "already has index" in error_msg): @@ -313,44 +314,44 @@ class ClickzettaVector(BaseVector): for i in range(0, len(documents), batch_size): batch_docs = documents[i:i + batch_size] batch_embeddings = embeddings[i:i + batch_size] - + # Execute batch insert through write queue self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) - - def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]], + + def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]], batch_index: int, batch_size: int, total_batches: int): """Insert a batch of documents using parameterized queries (executed in write worker thread).""" if not batch_docs or not batch_embeddings: logger.warning("Empty batch provided, skipping insertion") return - + if len(batch_docs) != len(batch_embeddings): logger.error(f"Mismatch between docs ({len(batch_docs)}) and embeddings ({len(batch_embeddings)})") return - + # Prepare data for parameterized insertion data_rows = [] vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 - + for doc, embedding in zip(batch_docs, batch_embeddings): # Optimized: minimal checks for common case, fallback for edge cases metadata = doc.metadata if doc.metadata else {} - + if not isinstance(metadata, dict): metadata = {} - + doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) - + # Fast path for JSON serialization try: metadata_json = json.dumps(metadata, ensure_ascii=True) except (TypeError, ValueError): logger.warning("JSON serialization failed, using empty dict") metadata_json = "{}" - + content = doc.page_content or "" - - # According to ClickZetta docs, vector should be formatted as array string + + # According to ClickZetta docs, vector should be formatted as array string # for external systems: '[1.0, 2.0, 3.0]' vector_str = '[' + ','.join(map(str, embedding)) + ']' data_rows.append([doc_id, content, metadata_json, vector_str]) @@ -359,17 +360,22 @@ class ClickzettaVector(BaseVector): if not data_rows: logger.warning(f"No valid documents to insert in batch {batch_index // batch_size + 1}/{total_batches}") return - + # Use parameterized INSERT with executemany for better performance and security # Cast JSON and VECTOR in SQL, pass raw data as parameters columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}" - insert_sql = f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" - + insert_sql = ( + f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) " + f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" + ) + with self._connection.cursor() as cursor: try: cursor.executemany(insert_sql, data_rows) - logger.info(f"Inserted batch {batch_index // batch_size + 1}/{total_batches} " - f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)") + logger.info( + f"Inserted batch {batch_index // batch_size + 1}/{total_batches} " + f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)" + ) except Exception as e: logger.exception(f"Parameterized SQL execution failed for {len(data_rows)} documents: {e}") logger.exception(f"SQL template: {insert_sql}") @@ -399,14 +405,14 @@ class ClickzettaVector(BaseVector): # Execute delete through write queue self._execute_write(self._delete_by_ids_impl, ids) - + def _delete_by_ids_impl(self, ids: list[str]) -> None: """Implementation of delete by IDs (executed in write worker thread).""" safe_ids = [self._safe_doc_id(id) for id in ids] # Create properly escaped string literals for SQL id_list = ",".join(f"'{id}'" for id in safe_ids) sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})" - + with self._connection.cursor() as cursor: cursor.execute(sql) @@ -419,7 +425,7 @@ class ClickzettaVector(BaseVector): # Execute delete through write queue self._execute_write(self._delete_by_metadata_field_impl, key, value) - + def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: """Implementation of delete by metadata field (executed in write worker thread).""" with self._connection.cursor() as cursor: @@ -435,7 +441,7 @@ class ClickzettaVector(BaseVector): top_k = kwargs.get("top_k", 10) score_threshold = kwargs.get("score_threshold", 0.0) document_ids_filter = kwargs.get("document_ids_filter") - + # Handle filter parameter from canvas (workflow) filter_param = kwargs.get("filter", {}) @@ -445,8 +451,10 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})") - + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + # No need for dataset_id filter since each dataset has its own table # Add distance threshold based on distance function @@ -489,11 +497,11 @@ class ClickzettaVector(BaseVector): try: if row[2]: metadata = json.loads(row[2]) - + # If result is a string, it's double-encoded JSON - parse again if isinstance(metadata, str): metadata = json.loads(metadata) - + if not isinstance(metadata, dict): metadata = {} else: @@ -504,14 +512,14 @@ class ClickzettaVector(BaseVector): import re doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} - + # Ensure required fields are set metadata["doc_id"] = row[0] # segment id - + # Ensure document_id exists (critical for Dify's format_retrieval_documents) if "document_id" not in metadata: metadata["document_id"] = row[0] # fallback to segment id - + # Add score based on distance if self._config.vector_distance_function == "cosine_distance": metadata["score"] = 1 - (row[3] / 2) @@ -531,7 +539,7 @@ class ClickzettaVector(BaseVector): top_k = kwargs.get("top_k", 10) document_ids_filter = kwargs.get("document_ids_filter") - + # Handle filter parameter from canvas (workflow) filter_param = kwargs.get("filter", {}) @@ -541,8 +549,10 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})") - + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + # No need for dataset_id filter since each dataset has its own table # Use match_all function for full-text search @@ -572,11 +582,11 @@ class ClickzettaVector(BaseVector): try: if row[2]: metadata = json.loads(row[2]) - + # If result is a string, it's double-encoded JSON - parse again if isinstance(metadata, str): metadata = json.loads(metadata) - + if not isinstance(metadata, dict): metadata = {} else: @@ -587,14 +597,14 @@ class ClickzettaVector(BaseVector): import re doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} - + # Ensure required fields are set metadata["doc_id"] = row[0] # segment id - + # Ensure document_id exists (critical for Dify's format_retrieval_documents) if "document_id" not in metadata: metadata["document_id"] = row[0] # fallback to segment id - + # Add a relevance score for full-text search metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores doc = Document(page_content=row[1], metadata=metadata) @@ -610,7 +620,7 @@ class ClickzettaVector(BaseVector): """Fallback search using LIKE operator.""" top_k = kwargs.get("top_k", 10) document_ids_filter = kwargs.get("document_ids_filter") - + # Handle filter parameter from canvas (workflow) filter_param = kwargs.get("filter", {}) @@ -620,8 +630,10 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append(f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})") - + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + # No need for dataset_id filter since each dataset has its own table # Use simple quote escaping for LIKE clause @@ -646,11 +658,11 @@ class ClickzettaVector(BaseVector): try: if row[2]: metadata = json.loads(row[2]) - + # If result is a string, it's double-encoded JSON - parse again if isinstance(metadata, str): metadata = json.loads(metadata) - + if not isinstance(metadata, dict): metadata = {} else: @@ -661,14 +673,14 @@ class ClickzettaVector(BaseVector): import re doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} - + # Ensure required fields are set metadata["doc_id"] = row[0] # segment id - + # Ensure document_id exists (critical for Dify's format_retrieval_documents) if "document_id" not in metadata: metadata["document_id"] = row[0] # fallback to segment id - + metadata["score"] = 0.5 # Lower score for LIKE search doc = Document(page_content=row[1], metadata=metadata) documents.append(doc) @@ -680,11 +692,11 @@ class ClickzettaVector(BaseVector): with self._connection.cursor() as cursor: cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") - + def _format_vector_simple(self, vector: list[float]) -> str: """Simple vector formatting for SQL queries.""" return ','.join(map(str, vector)) - + def _safe_doc_id(self, doc_id: str) -> str: """Ensure doc_id is safe for SQL and doesn't contain special characters.""" if not doc_id: @@ -696,7 +708,7 @@ class ClickzettaVector(BaseVector): if not safe_id: # If all characters were removed return str(uuid.uuid4()) return safe_id[:255] # Limit length - + class ClickzettaVectorFactory(AbstractVectorFactory): @@ -724,3 +736,4 @@ class ClickzettaVectorFactory(AbstractVectorFactory): collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower() return ClickzettaVector(collection_name=collection_name, config=config) + diff --git a/api/extensions/storage/clickzetta_volume/__init__.py b/api/extensions/storage/clickzetta_volume/__init__.py index 6117e57e44..8a1588034b 100644 --- a/api/extensions/storage/clickzetta_volume/__init__.py +++ b/api/extensions/storage/clickzetta_volume/__init__.py @@ -2,4 +2,4 @@ from .clickzetta_volume_storage import ClickZettaVolumeStorage -__all__ = ["ClickZettaVolumeStorage"] \ No newline at end of file +__all__ = ["ClickZettaVolumeStorage"] diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index bd0c3ea1fc..150412a899 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -526,4 +526,4 @@ class ClickZettaVolumeStorage(BaseStorage): except Exception as e: logger.error(f"Error scanning path {path}: {e}") - return [] \ No newline at end of file + return [] diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index bb140dd139..9e36e97328 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -508,4 +508,4 @@ class FileLifecycleManager: except Exception as e: logger.error(f"Permission check failed for {filename} operation {operation}: {e}") # 安全默认:权限检查失败时拒绝访问 - return False \ No newline at end of file + return False diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 4b76c625c5..9d52b80b46 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -22,10 +22,10 @@ class VolumePermission(Enum): class VolumePermissionManager: """Volume权限管理器""" - + def __init__(self, connection_or_config, volume_type: str = None, volume_name: Optional[str] = None): """初始化权限管理器 - + Args: connection_or_config: ClickZetta连接对象或配置字典 volume_type: Volume类型 (user|table|external) @@ -52,22 +52,22 @@ class VolumePermissionManager: self._connection = connection_or_config self._volume_type = volume_type self._volume_name = volume_name - + if not self._connection: raise ValueError("Valid connection or config is required") if not self._volume_type: raise ValueError("volume_type is required") - + self._permission_cache: Dict[str, Set[str]] = {} self._current_username = None # 将从连接中获取当前用户名 - + def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: """检查用户是否有执行特定操作的权限 - + Args: operation: 要执行的操作类型 dataset_id: 数据集ID (用于table volume) - + Returns: True if user has permission, False otherwise """ @@ -81,14 +81,14 @@ class VolumePermissionManager: else: logger.warning(f"Unknown volume type: {self._volume_type}") return False - + except Exception as e: logger.error(f"Permission check failed: {e}") return False - + def _check_user_volume_permission(self, operation: VolumePermission) -> bool: """检查User Volume权限 - + User Volume权限规则: - 用户对自己的User Volume有全部权限 - 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限 @@ -97,29 +97,34 @@ class VolumePermissionManager: try: # 获取当前用户名 current_user = self._get_current_username() - + # 检查基本连接状态 with self._connection.cursor() as cursor: # 简单的连接测试,如果能执行查询说明用户有基本权限 cursor.execute("SELECT 1") result = cursor.fetchone() - + if result: - logger.debug(f"User Volume permission check for {current_user}, operation {operation.name}: granted (basic connection verified)") + logger.debug( + f"User Volume permission check for {current_user}, operation {operation.name}: " + f"granted (basic connection verified)" + ) return True else: - logger.warning(f"User Volume permission check failed: cannot verify basic connection for {current_user}") + logger.warning( + f"User Volume permission check failed: cannot verify basic connection for {current_user}" + ) return False - + except Exception as e: logger.error(f"User Volume permission check failed: {e}") # 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示 logger.info(f"User Volume permission check failed, but permission checking is disabled in this version") return False - + def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: """检查Table Volume权限 - + Table Volume权限规则: - Table Volume权限继承对应表的权限 - SELECT权限 -> 可以READ/LIST文件 @@ -128,29 +133,29 @@ class VolumePermissionManager: if not dataset_id: logger.warning("dataset_id is required for table volume permission check") return False - + table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id - + try: # 检查表权限 permissions = self._get_table_permissions(table_name) required_permissions = set(operation.value.split(",")) - + # 检查是否有所需的所有权限 has_permission = required_permissions.issubset(permissions) - + logger.debug(f"Table Volume permission check for {table_name}, operation {operation.name}: " f"required={required_permissions}, has={permissions}, granted={has_permission}") - + return has_permission - + except Exception as e: logger.error(f"Table volume permission check failed for {table_name}: {e}") return False - + def _check_external_volume_permission(self, operation: VolumePermission) -> bool: """检查External Volume权限 - + External Volume权限规则: - 尝试获取对External Volume的权限 - 如果权限检查失败,进行备选验证 @@ -159,29 +164,29 @@ class VolumePermissionManager: if not self._volume_name: logger.warning("volume_name is required for external volume permission check") return False - + try: # 检查External Volume权限 permissions = self._get_external_volume_permissions(self._volume_name) - + # External Volume权限映射:根据操作类型确定所需权限 required_permissions = set() - + if operation in [VolumePermission.READ, VolumePermission.LIST]: required_permissions.add("read") elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]: required_permissions.add("write") - + # 检查是否有所需的所有权限 has_permission = required_permissions.issubset(permissions) - + logger.debug(f"External Volume permission check for {self._volume_name}, operation {operation.name}: " f"required={required_permissions}, has={permissions}, granted={has_permission}") - + # 如果权限检查失败,尝试备选验证 if not has_permission: logger.info(f"Direct permission check failed for {self._volume_name}, trying fallback verification") - + # 备选验证:尝试列出Volume来验证基本访问权限 try: with self._connection.cursor() as cursor: @@ -193,43 +198,43 @@ class VolumePermissionManager: return True except Exception as fallback_e: logger.warning(f"Fallback verification failed for {self._volume_name}: {fallback_e}") - + return has_permission - + except Exception as e: logger.error(f"External volume permission check failed for {self._volume_name}: {e}") logger.info(f"External Volume permission check failed, but permission checking is disabled in this version") return False - + def _get_table_permissions(self, table_name: str) -> Set[str]: """获取用户对指定表的权限 - + Args: table_name: 表名 - + Returns: 用户对该表的权限集合 """ cache_key = f"table:{table_name}" - + if cache_key in self._permission_cache: return self._permission_cache[cache_key] - + permissions = set() - + try: with self._connection.cursor() as cursor: # 使用正确的ClickZetta语法检查当前用户权限 cursor.execute("SHOW GRANTS") grants = cursor.fetchall() - + # 解析权限结果,查找对该表的权限 for grant in grants: if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) privilege = grant[0].upper() object_type = grant[1].upper() if len(grant) > 1 else "" object_name = grant[2] if len(grant) > 2 else "" - + # 检查是否是对该表的权限 if object_type == "TABLE" and object_name == table_name: if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: @@ -244,7 +249,7 @@ class VolumePermissionManager: permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) else: permissions.add(privilege) - + # 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限 if not permissions: try: @@ -252,21 +257,21 @@ class VolumePermissionManager: permissions.add("SELECT") except Exception: logger.debug(f"Cannot query table {table_name}, no SELECT permission") - + except Exception as e: logger.warning(f"Could not check table permissions for {table_name}: {e}") # 安全默认:权限检查失败时拒绝访问 pass - + # 缓存权限信息 self._permission_cache[cache_key] = permissions return permissions - + def _get_current_username(self) -> str: """获取当前用户名""" if self._current_username: return self._current_username - + try: with self._connection.cursor() as cursor: cursor.execute("SELECT CURRENT_USER()") @@ -276,73 +281,74 @@ class VolumePermissionManager: return self._current_username except Exception as e: logger.error(f"Failed to get current username: {e}") - + return "unknown" - + def _get_user_permissions(self, username: str) -> Set[str]: """获取用户的基本权限集合""" cache_key = f"user_permissions:{username}" - + if cache_key in self._permission_cache: return self._permission_cache[cache_key] - + permissions = set() - + try: with self._connection.cursor() as cursor: # 使用正确的ClickZetta语法检查当前用户权限 cursor.execute("SHOW GRANTS") grants = cursor.fetchall() - + # 解析权限结果,查找用户的基本权限 for grant in grants: if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) privilege = grant[0].upper() object_type = grant[1].upper() if len(grant) > 1 else "" - + # 收集所有相关权限 if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: if privilege == "ALL": permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) else: permissions.add(privilege) - + except Exception as e: logger.warning(f"Could not check user permissions for {username}: {e}") # 安全默认:权限检查失败时拒绝访问 pass - + # 缓存权限信息 self._permission_cache[cache_key] = permissions return permissions - + def _get_external_volume_permissions(self, volume_name: str) -> Set[str]: """获取用户对指定External Volume的权限 - + Args: volume_name: External Volume名称 - + Returns: 用户对该Volume的权限集合 """ cache_key = f"external_volume:{volume_name}" - + if cache_key in self._permission_cache: return self._permission_cache[cache_key] - + permissions = set() - + try: with self._connection.cursor() as cursor: # 使用正确的ClickZetta语法检查Volume权限 logger.info(f"Checking permissions for volume: {volume_name}") cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}") grants = cursor.fetchall() - + logger.info(f"Raw grants result for {volume_name}: {grants}") - + # 解析权限结果 - # 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to, grantee_name, grantor_name, grant_option, granted_time) + # 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to, + # grantee_name, grantor_name, grant_option, granted_time) for grant in grants: logger.info(f"Processing grant: {grant}") if len(grant) >= 5: @@ -350,15 +356,19 @@ class VolumePermissionManager: privilege = grant[1].upper() granted_on = grant[3] object_name = grant[4] - - logger.info(f"Grant details - type: {granted_type}, privilege: {privilege}, granted_on: {granted_on}, object_name: {object_name}") - + + logger.info( + f"Grant details - type: {granted_type}, privilege: {privilege}, " + f"granted_on: {granted_on}, object_name: {object_name}" + ) + # 检查是否是对该Volume的权限或者是层级权限 - if (granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)) or \ - (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"): - + if ((granted_type == "PRIVILEGE" and granted_on == "VOLUME" and + object_name.endswith(volume_name)) or + (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME")): + logger.info(f"Matching grant found for {volume_name}") - + if "READ" in privilege: permissions.add("read") logger.info(f"Added READ permission for {volume_name}") @@ -371,9 +381,9 @@ class VolumePermissionManager: if privilege == "ALL": permissions.update(["read", "write", "alter"]) logger.info(f"Added ALL permissions for {volume_name}") - + logger.info(f"Final permissions for {volume_name}: {permissions}") - + # 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限 if not permissions: try: @@ -386,7 +396,7 @@ class VolumePermissionManager: break except Exception: logger.debug(f"Cannot access volume {volume_name}, no basic permission") - + except Exception as e: logger.warning(f"Could not check external volume permissions for {volume_name}: {e}") # 在权限检查失败时,尝试基本的Volume访问验证 @@ -404,102 +414,102 @@ class VolumePermissionManager: logger.warning(f"Basic volume access check failed for {volume_name}: {basic_e}") # 最后的备选方案:假设有基本权限 permissions.add("read") - + # 缓存权限信息 self._permission_cache[cache_key] = permissions return permissions - + def clear_permission_cache(self): """清空权限缓存""" self._permission_cache.clear() logger.debug("Permission cache cleared") - + def get_permission_summary(self, dataset_id: Optional[str] = None) -> Dict[str, bool]: """获取权限摘要 - + Args: dataset_id: 数据集ID (用于table volume) - + Returns: 权限摘要字典 """ summary = {} - + for operation in VolumePermission: summary[operation.name.lower()] = self.check_permission(operation, dataset_id) - + return summary - + def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool: """检查文件路径的权限继承 - + Args: file_path: 文件路径 operation: 要执行的操作 - + Returns: True if user has permission, False otherwise """ try: # 解析文件路径 path_parts = file_path.strip("/").split("/") - + if not path_parts: logger.warning("Invalid file path for permission inheritance check") return False - + # 对于Table Volume,第一层是dataset_id if self._volume_type == "table": if len(path_parts) < 1: return False - + dataset_id = path_parts[0] - + # 检查对dataset的权限 has_dataset_permission = self.check_permission(operation, dataset_id) - + if not has_dataset_permission: logger.debug(f"Permission denied for dataset {dataset_id}") return False - + # 检查路径遍历攻击 if self._contains_path_traversal(file_path): logger.warning(f"Path traversal attack detected: {file_path}") return False - + # 检查是否访问敏感目录 if self._is_sensitive_path(file_path): logger.warning(f"Access to sensitive path denied: {file_path}") return False - + logger.debug(f"Permission inherited for path {file_path}") return True - + elif self._volume_type == "user": # User Volume的权限继承 current_user = self._get_current_username() - + # 检查是否试图访问其他用户的目录 if len(path_parts) > 1 and path_parts[0] != current_user: logger.warning(f"User {current_user} attempted to access {path_parts[0]}'s directory") return False - + # 检查基本权限 return self.check_permission(operation) - + elif self._volume_type == "external": # External Volume的权限继承 # 检查对External Volume的权限 return self.check_permission(operation) - + else: logger.warning(f"Unknown volume type for permission inheritance: {self._volume_type}") return False - + except Exception as e: logger.error(f"Permission inheritance check failed: {e}") return False - + def _contains_path_traversal(self, file_path: str) -> bool: """检查路径是否包含路径遍历攻击""" # 检查常见的路径遍历模式 @@ -509,23 +519,23 @@ class VolumePermissionManager: "%2e%2e%2f", "%2e%2e%5c", "....//", "....\\\\", ] - + file_path_lower = file_path.lower() - + for pattern in traversal_patterns: if pattern in file_path_lower: return True - + # 检查绝对路径 if file_path.startswith("/") or file_path.startswith("\\"): return True - + # 检查Windows驱动器路径 if len(file_path) >= 2 and file_path[1] == ":": return True - + return False - + def _is_sensitive_path(self, file_path: str) -> bool: """检查路径是否为敏感路径""" sensitive_patterns = [ @@ -533,22 +543,22 @@ class VolumePermissionManager: "private", "key", "certificate", "cert", "ssl", "database", "backup", "dump", "log", "tmp" ] - + file_path_lower = file_path.lower() - + for pattern in sensitive_patterns: if pattern in file_path_lower: return True - + return False - + def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool: """验证操作权限 - + Args: operation: 操作名称 (save|load|exists|delete|scan) dataset_id: 数据集ID - + Returns: True if operation is allowed, False otherwise """ @@ -562,18 +572,18 @@ class VolumePermissionManager: "delete": VolumePermission.DELETE, "scan": VolumePermission.LIST, } - + if operation not in operation_mapping: logger.warning(f"Unknown operation: {operation}") return False - + volume_permission = operation_mapping[operation] return self.check_permission(volume_permission, dataset_id) class VolumePermissionError(Exception): """Volume权限错误异常""" - + def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None): self.operation = operation self.volume_type = volume_type @@ -581,16 +591,16 @@ class VolumePermissionError(Exception): super().__init__(message) -def check_volume_permission(permission_manager: VolumePermissionManager, - operation: str, +def check_volume_permission(permission_manager: VolumePermissionManager, + operation: str, dataset_id: Optional[str] = None) -> None: """权限检查装饰器函数 - + Args: permission_manager: 权限管理器 operation: 操作名称 dataset_id: 数据集ID - + Raises: VolumePermissionError: 如果没有权限 """ @@ -598,10 +608,10 @@ def check_volume_permission(permission_manager: VolumePermissionManager, error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume" if dataset_id: error_message += f" (dataset: {dataset_id})" - + raise VolumePermissionError( error_message, operation=operation, volume_type=permission_manager._volume_type, dataset_id=dataset_id - ) \ No newline at end of file + ) diff --git a/api/tests/integration_tests/storage/test_clickzetta_volume.py b/api/tests/integration_tests/storage/test_clickzetta_volume.py index b6ba4b3692..2ae8b27210 100644 --- a/api/tests/integration_tests/storage/test_clickzetta_volume.py +++ b/api/tests/integration_tests/storage/test_clickzetta_volume.py @@ -177,4 +177,4 @@ class TestClickZettaVolumeStorage(unittest.TestCase): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py index 751e013aed..1ca95c4f72 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py +++ b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py @@ -234,4 +234,4 @@ class TestClickzettaVector(AbstractVectorTest): # Clean up vector_store.delete_by_metadata_field("lang", "chinese") - vector_store.delete_by_metadata_field("lang", "english") \ No newline at end of file + vector_store.delete_by_metadata_field("lang", "english") diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py index 963df6e0f6..b8a83d63c0 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py +++ b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py @@ -162,4 +162,4 @@ def main(): return 1 if __name__ == "__main__": - exit(main()) \ No newline at end of file + exit(main()) diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 19be76f4ae..421dd2c23d 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -87,11 +87,12 @@ x-shared-env: &shared-api-worker-env WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} STORAGE_TYPE: ${STORAGE_TYPE:-opendal} + OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} + OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user} CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-} CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_} - OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} - OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} + CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km} S3_ENDPOINT: ${S3_ENDPOINT:-} S3_REGION: ${S3_REGION:-us-east-1} S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}