Fix remaining Python style and linting issues

- Fix line length violation in middleware config description
- Fix RUF013 type annotation to use union syntax
- Complete all Python style and linting fixes for CI checks
- Resolve formatter and linter warnings

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22551/head
yunqiqiliang 10 months ago
parent c3851595d0
commit b5a3f1d5e0

File diff suppressed because it is too large Load Diff

@ -64,8 +64,8 @@ class StorageConfig(BaseSettings):
"local", "local",
] = Field( ] = Field(
description="Type of storage to use." description="Type of storage to use."
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'clickzetta-volume', 'google-storage', " " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'clickzetta-volume', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.", "'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal", default="opendal",
) )

@ -8,57 +8,57 @@ from pydantic_settings import BaseSettings
class ClickZettaVolumeStorageConfig(BaseSettings): class ClickZettaVolumeStorageConfig(BaseSettings):
"""Configuration for ClickZetta Volume storage.""" """Configuration for ClickZetta Volume storage."""
CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field( CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field(
description="Username for ClickZetta Volume authentication", description="Username for ClickZetta Volume authentication",
default=None, default=None,
) )
CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field( CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field(
description="Password for ClickZetta Volume authentication", description="Password for ClickZetta Volume authentication",
default=None, default=None,
) )
CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field( CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field(
description="ClickZetta instance identifier", description="ClickZetta instance identifier",
default=None, default=None,
) )
CLICKZETTA_VOLUME_SERVICE: str = Field( CLICKZETTA_VOLUME_SERVICE: str = Field(
description="ClickZetta service endpoint", description="ClickZetta service endpoint",
default="api.clickzetta.com", default="api.clickzetta.com",
) )
CLICKZETTA_VOLUME_WORKSPACE: str = Field( CLICKZETTA_VOLUME_WORKSPACE: str = Field(
description="ClickZetta workspace name", description="ClickZetta workspace name",
default="quick_start", default="quick_start",
) )
CLICKZETTA_VOLUME_VCLUSTER: str = Field( CLICKZETTA_VOLUME_VCLUSTER: str = Field(
description="ClickZetta virtual cluster name", description="ClickZetta virtual cluster name",
default="default_ap", default="default_ap",
) )
CLICKZETTA_VOLUME_SCHEMA: str = Field( CLICKZETTA_VOLUME_SCHEMA: str = Field(
description="ClickZetta schema name", description="ClickZetta schema name",
default="dify", default="dify",
) )
CLICKZETTA_VOLUME_TYPE: str = Field( CLICKZETTA_VOLUME_TYPE: str = Field(
description="ClickZetta volume type (table|user|external)", description="ClickZetta volume type (table|user|external)",
default="user", default="user",
) )
CLICKZETTA_VOLUME_NAME: Optional[str] = Field( CLICKZETTA_VOLUME_NAME: Optional[str] = Field(
description="ClickZetta volume name for external volumes", description="ClickZetta volume name for external volumes",
default=None, default=None,
) )
CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field( CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
description="Prefix for ClickZetta volume table names", description="Prefix for ClickZetta volume table names",
default="dataset_", default="dataset_",
) )
CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field( CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
description="Directory prefix for User Volume to organize Dify files", description="Directory prefix for User Volume to organize Dify files",
default="dify_km", default="dify_km",

@ -67,4 +67,3 @@ class ClickzettaConfig(BaseModel):
description="Distance function for vector similarity: l2_distance or cosine_distance", description="Distance function for vector similarity: l2_distance or cosine_distance",
default="cosine_distance", default="cosine_distance",
) )

@ -80,7 +80,7 @@ class Storage:
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set # and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
volume_config = ClickZettaVolumeConfig() volume_config = ClickZettaVolumeConfig()
return ClickZettaVolumeStorage(volume_config) return ClickZettaVolumeStorage(volume_config)
return create_clickzetta_volume_storage return create_clickzetta_volume_storage
case _: case _:
raise ValueError(f"unsupported storage type {storage_type}") raise ValueError(f"unsupported storage type {storage_type}")

@ -16,6 +16,7 @@ import clickzetta # type: ignore[import]
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage
from .volume_permissions import VolumePermissionManager, check_volume_permission from .volume_permissions import VolumePermissionManager, check_volume_permission
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class ClickZettaVolumeConfig(BaseModel): class ClickZettaVolumeConfig(BaseModel):
"""Configuration for ClickZetta Volume storage.""" """Configuration for ClickZetta Volume storage."""
username: str username: str
password: str password: str
instance: str instance: str
@ -36,52 +37,51 @@ class ClickZettaVolumeConfig(BaseModel):
table_prefix: str = "dataset_" # Prefix for table volume names table_prefix: str = "dataset_" # Prefix for table volume names
dify_prefix: str = "dify_km" # Directory prefix for User Volume dify_prefix: str = "dify_km" # Directory prefix for User Volume
permission_check: bool = True # Enable/disable permission checking permission_check: bool = True # Enable/disable permission checking
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
"""Validate the configuration values. """Validate the configuration values.
This method will first try to use CLICKZETTA_VOLUME_* environment variables, This method will first try to use CLICKZETTA_VOLUME_* environment variables,
then fall back to CLICKZETTA_* environment variables (for vector DB config). then fall back to CLICKZETTA_* environment variables (for vector DB config).
""" """
import os import os
# Helper function to get environment variable with fallback # Helper function to get environment variable with fallback
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str = None) -> str: def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
# First try CLICKZETTA_VOLUME_* specific config # First try CLICKZETTA_VOLUME_* specific config
volume_value = values.get(volume_key.lower().replace('clickzetta_volume_', '')) volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
if volume_value: if volume_value:
return volume_value return volume_value
# Then try environment variables # Then try environment variables
volume_env = os.getenv(volume_key) volume_env = os.getenv(volume_key)
if volume_env: if volume_env:
return volume_env return volume_env
# Fall back to existing CLICKZETTA_* config # Fall back to existing CLICKZETTA_* config
fallback_env = os.getenv(fallback_key) fallback_env = os.getenv(fallback_key)
if fallback_env: if fallback_env:
return fallback_env return fallback_env
return default return default
# Apply environment variables with fallback to existing CLICKZETTA_* config # Apply environment variables with fallback to existing CLICKZETTA_* config
values.setdefault("username", get_env_with_fallback( values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
"CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME")) values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
values.setdefault("password", get_env_with_fallback( values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
"CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD")) values.setdefault(
values.setdefault("instance", get_env_with_fallback( "service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
"CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE")) )
values.setdefault("service", get_env_with_fallback( values.setdefault(
"CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")) "workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
values.setdefault("workspace", get_env_with_fallback( )
"CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")) values.setdefault(
values.setdefault("vcluster", get_env_with_fallback( "vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
"CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")) )
values.setdefault("schema_name", get_env_with_fallback( values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
"CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
# Volume-specific configurations (no fallback to vector DB config) # Volume-specific configurations (no fallback to vector DB config)
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table")) values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME")) values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
@ -89,7 +89,7 @@ class ClickZettaVolumeConfig(BaseModel):
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km")) values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
# 暂时禁用权限检查功能直接设置为false # 暂时禁用权限检查功能直接设置为false
values.setdefault("permission_check", False) values.setdefault("permission_check", False)
# Validate required fields # Validate required fields
if not values.get("username"): if not values.get("username"):
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required") raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
@ -97,24 +97,24 @@ class ClickZettaVolumeConfig(BaseModel):
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required") raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
if not values.get("instance"): if not values.get("instance"):
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required") raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
# Validate volume type # Validate volume type
volume_type = values["volume_type"] volume_type = values["volume_type"]
if volume_type not in ["table", "user", "external"]: if volume_type not in ["table", "user", "external"]:
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external") raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
if volume_type == "external" and not values.get("volume_name"): if volume_type == "external" and not values.get("volume_name"):
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type") raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
return values return values
class ClickZettaVolumeStorage(BaseStorage): class ClickZettaVolumeStorage(BaseStorage):
"""ClickZetta Volume storage implementation.""" """ClickZetta Volume storage implementation."""
def __init__(self, config: ClickZettaVolumeConfig): def __init__(self, config: ClickZettaVolumeConfig):
"""Initialize ClickZetta Volume storage. """Initialize ClickZetta Volume storage.
Args: Args:
config: ClickZetta Volume configuration config: ClickZetta Volume configuration
""" """
@ -123,9 +123,9 @@ class ClickZettaVolumeStorage(BaseStorage):
self._permission_manager = None self._permission_manager = None
self._init_connection() self._init_connection()
self._init_permission_manager() self._init_permission_manager()
logger.info(f"ClickZetta Volume storage initialized with type: {config.volume_type}") logger.info(f"ClickZetta Volume storage initialized with type: {config.volume_type}")
def _init_connection(self): def _init_connection(self):
"""Initialize ClickZetta connection.""" """Initialize ClickZetta connection."""
try: try:
@ -136,26 +136,24 @@ class ClickZettaVolumeStorage(BaseStorage):
service=self._config.service, service=self._config.service,
workspace=self._config.workspace, workspace=self._config.workspace,
vcluster=self._config.vcluster, vcluster=self._config.vcluster,
schema=self._config.schema_name schema=self._config.schema_name,
) )
logger.debug("ClickZetta connection established") logger.debug("ClickZetta connection established")
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to ClickZetta: {e}") logger.error(f"Failed to connect to ClickZetta: {e}")
raise raise
def _init_permission_manager(self): def _init_permission_manager(self):
"""Initialize permission manager.""" """Initialize permission manager."""
try: try:
self._permission_manager = VolumePermissionManager( self._permission_manager = VolumePermissionManager(
self._connection, self._connection, self._config.volume_type, self._config.volume_name
self._config.volume_type,
self._config.volume_name
) )
logger.debug("Permission manager initialized") logger.debug("Permission manager initialized")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize permission manager: {e}") logger.error(f"Failed to initialize permission manager: {e}")
raise raise
def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str: def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
"""Get the appropriate volume path based on volume type.""" """Get the appropriate volume path based on volume type."""
if self._config.volume_type == "user": if self._config.volume_type == "user":
@ -166,7 +164,7 @@ class ClickZettaVolumeStorage(BaseStorage):
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files"]: if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files"]:
# Use User Volume with dify prefix for special directories # Use User Volume with dify prefix for special directories
return f"{self._config.dify_prefix}/{filename}" return f"{self._config.dify_prefix}/{filename}"
if dataset_id: if dataset_id:
return f"{self._config.table_prefix}{dataset_id}/{filename}" return f"{self._config.table_prefix}{dataset_id}/{filename}"
else: else:
@ -180,7 +178,7 @@ class ClickZettaVolumeStorage(BaseStorage):
return filename return filename
else: else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}") raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str: def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
"""Get SQL prefix for volume operations.""" """Get SQL prefix for volume operations."""
if self._config.volume_type == "user": if self._config.volume_type == "user":
@ -191,7 +189,7 @@ class ClickZettaVolumeStorage(BaseStorage):
# These should use USER VOLUME for better compatibility # These should use USER VOLUME for better compatibility
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files"]: if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files"]:
return "USER VOLUME" return "USER VOLUME"
# Only use TABLE VOLUME for actual dataset-specific paths # Only use TABLE VOLUME for actual dataset-specific paths
# like "dataset_12345/file.pdf" or paths with dataset_ prefix # like "dataset_12345/file.pdf" or paths with dataset_ prefix
if dataset_id: if dataset_id:
@ -204,7 +202,7 @@ class ClickZettaVolumeStorage(BaseStorage):
return f"VOLUME {self._config.volume_name}" return f"VOLUME {self._config.volume_name}"
else: else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}") raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _execute_sql(self, sql: str, fetch: bool = False): def _execute_sql(self, sql: str, fetch: bool = False):
"""Execute SQL command.""" """Execute SQL command."""
try: try:
@ -216,23 +214,23 @@ class ClickZettaVolumeStorage(BaseStorage):
except Exception as e: except Exception as e:
logger.error(f"SQL execution failed: {sql}, Error: {e}") logger.error(f"SQL execution failed: {sql}, Error: {e}")
raise raise
def _ensure_table_volume_exists(self, dataset_id: str) -> None: def _ensure_table_volume_exists(self, dataset_id: str) -> None:
"""Ensure table volume exists for the given dataset_id.""" """Ensure table volume exists for the given dataset_id."""
if self._config.volume_type != "table" or not dataset_id: if self._config.volume_type != "table" or not dataset_id:
return return
# Skip for upload_files and other special directories that use USER VOLUME # Skip for upload_files and other special directories that use USER VOLUME
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files"]: if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files"]:
return return
table_name = f"{self._config.table_prefix}{dataset_id}" table_name = f"{self._config.table_prefix}{dataset_id}"
try: try:
# Check if table exists # Check if table exists
check_sql = f"SHOW TABLES LIKE '{table_name}'" check_sql = f"SHOW TABLES LIKE '{table_name}'"
result = self._execute_sql(check_sql, fetch=True) result = self._execute_sql(check_sql, fetch=True)
if not result: if not result:
# Create table with volume # Create table with volume
create_sql = f""" create_sql = f"""
@ -246,15 +244,15 @@ class ClickZettaVolumeStorage(BaseStorage):
""" """
self._execute_sql(create_sql) self._execute_sql(create_sql)
logger.info(f"Created table volume: {table_name}") logger.info(f"Created table volume: {table_name}")
except Exception as e: except Exception as e:
logger.warning(f"Failed to create table volume {table_name}: {e}") logger.warning(f"Failed to create table volume {table_name}: {e}")
# Don't raise exception, let the operation continue # Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions # The table might exist but not be visible due to permissions
def save(self, filename: str, data: bytes) -> None: def save(self, filename: str, data: bytes) -> None:
"""Save data to ClickZetta Volume. """Save data to ClickZetta Volume.
Args: Args:
filename: File path in volume filename: File path in volume
data: File content as bytes data: File content as bytes
@ -264,53 +262,53 @@ class ClickZettaVolumeStorage(BaseStorage):
if "/" in filename and self._config.volume_type == "table": if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1) parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix): if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):] dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1] filename = parts[1]
else: else:
dataset_id = parts[0] dataset_id = parts[0]
filename = parts[1] filename = parts[1]
# Ensure table volume exists (for table volumes) # Ensure table volume exists (for table volumes)
if dataset_id: if dataset_id:
self._ensure_table_volume_exists(dataset_id) self._ensure_table_volume_exists(dataset_id)
# Check permissions (if enabled) # Check permissions (if enabled)
if self._config.permission_check: if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME # Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files"]: if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files"]:
check_volume_permission(self._permission_manager, "save", dataset_id) check_volume_permission(self._permission_manager, "save", dataset_id)
# Write data to temporary file # Write data to temporary file
with tempfile.NamedTemporaryFile(delete=False) as temp_file: with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_file.write(data) temp_file.write(data)
temp_file_path = temp_file.name temp_file_path = temp_file.name
try: try:
# Upload to volume # Upload to volume
volume_prefix = self._get_volume_sql_prefix(dataset_id) volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix) # Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id) volume_path = self._get_volume_path(filename, dataset_id)
actual_filename = volume_path.split('/')[-1] if '/' in volume_path else volume_path actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
# For User Volume, use the full path with dify_km prefix # For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME": if volume_prefix == "USER VOLUME":
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'" sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
else: else:
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'" sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
self._execute_sql(sql) self._execute_sql(sql)
logger.debug(f"File {filename} saved to ClickZetta Volume at path {volume_path}") logger.debug(f"File {filename} saved to ClickZetta Volume at path {volume_path}")
finally: finally:
# Clean up temporary file # Clean up temporary file
Path(temp_file_path).unlink(missing_ok=True) Path(temp_file_path).unlink(missing_ok=True)
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
"""Load file content from ClickZetta Volume. """Load file content from ClickZetta Volume.
Args: Args:
filename: File path in volume filename: File path in volume
Returns: Returns:
File content as bytes File content as bytes
""" """
@ -319,33 +317,33 @@ class ClickZettaVolumeStorage(BaseStorage):
if "/" in filename and self._config.volume_type == "table": if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1) parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix): if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):] dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1] filename = parts[1]
else: else:
dataset_id = parts[0] dataset_id = parts[0]
filename = parts[1] filename = parts[1]
# Check permissions (if enabled) # Check permissions (if enabled)
if self._config.permission_check: if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME # Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files"]: if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files"]:
check_volume_permission(self._permission_manager, "load_once", dataset_id) check_volume_permission(self._permission_manager, "load_once", dataset_id)
# Download to temporary directory # Download to temporary directory
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
volume_prefix = self._get_volume_sql_prefix(dataset_id) volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix) # Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id) volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix # For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME": if volume_prefix == "USER VOLUME":
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'" sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
else: else:
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'" sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
self._execute_sql(sql) self._execute_sql(sql)
# Find the downloaded file (may be in subdirectories) # Find the downloaded file (may be in subdirectories)
downloaded_file = None downloaded_file = None
for root, dirs, files in os.walk(temp_dir): for root, dirs, files in os.walk(temp_dir):
@ -355,52 +353,52 @@ class ClickZettaVolumeStorage(BaseStorage):
break break
if downloaded_file: if downloaded_file:
break break
if not downloaded_file or not downloaded_file.exists(): if not downloaded_file or not downloaded_file.exists():
raise FileNotFoundError(f"Downloaded file not found: {filename}") raise FileNotFoundError(f"Downloaded file not found: {filename}")
content = downloaded_file.read_bytes() content = downloaded_file.read_bytes()
logger.debug(f"File {filename} loaded from ClickZetta Volume") logger.debug(f"File {filename} loaded from ClickZetta Volume")
return content return content
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:
"""Load file as stream from ClickZetta Volume. """Load file as stream from ClickZetta Volume.
Args: Args:
filename: File path in volume filename: File path in volume
Yields: Yields:
File content chunks File content chunks
""" """
content = self.load_once(filename) content = self.load_once(filename)
batch_size = 4096 batch_size = 4096
stream = BytesIO(content) stream = BytesIO(content)
while chunk := stream.read(batch_size): while chunk := stream.read(batch_size):
yield chunk yield chunk
logger.debug(f"File {filename} loaded as stream from ClickZetta Volume") logger.debug(f"File {filename} loaded as stream from ClickZetta Volume")
def download(self, filename: str, target_filepath: str): def download(self, filename: str, target_filepath: str):
"""Download file from ClickZetta Volume to local path. """Download file from ClickZetta Volume to local path.
Args: Args:
filename: File path in volume filename: File path in volume
target_filepath: Local target file path target_filepath: Local target file path
""" """
content = self.load_once(filename) content = self.load_once(filename)
with Path(target_filepath).open("wb") as f: with Path(target_filepath).open("wb") as f:
f.write(content) f.write(content)
logger.debug(f"File {filename} downloaded from ClickZetta Volume to {target_filepath}") logger.debug(f"File {filename} downloaded from ClickZetta Volume to {target_filepath}")
def exists(self, filename: str) -> bool: def exists(self, filename: str) -> bool:
"""Check if file exists in ClickZetta Volume. """Check if file exists in ClickZetta Volume.
Args: Args:
filename: File path in volume filename: File path in volume
Returns: Returns:
True if file exists, False otherwise True if file exists, False otherwise
""" """
@ -410,76 +408,76 @@ class ClickZettaVolumeStorage(BaseStorage):
if "/" in filename and self._config.volume_type == "table": if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1) parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix): if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):] dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1] filename = parts[1]
else: else:
dataset_id = parts[0] dataset_id = parts[0]
filename = parts[1] filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id) volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix) # Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id) volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix # For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME": if volume_prefix == "USER VOLUME":
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'" sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
else: else:
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'" sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
rows = self._execute_sql(sql, fetch=True) rows = self._execute_sql(sql, fetch=True)
exists = len(rows) > 0 exists = len(rows) > 0
logger.debug(f"File {filename} exists check: {exists}") logger.debug(f"File {filename} exists check: {exists}")
return exists return exists
except Exception as e: except Exception as e:
logger.warning(f"Error checking file existence for {filename}: {e}") logger.warning(f"Error checking file existence for {filename}: {e}")
return False return False
def delete(self, filename: str): def delete(self, filename: str):
"""Delete file from ClickZetta Volume. """Delete file from ClickZetta Volume.
Args: Args:
filename: File path in volume filename: File path in volume
""" """
if not self.exists(filename): if not self.exists(filename):
logger.debug(f"File {filename} not found, skip delete") logger.debug(f"File {filename} not found, skip delete")
return return
# Extract dataset_id from filename if present # Extract dataset_id from filename if present
dataset_id = None dataset_id = None
if "/" in filename and self._config.volume_type == "table": if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1) parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix): if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):] dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1] filename = parts[1]
else: else:
dataset_id = parts[0] dataset_id = parts[0]
filename = parts[1] filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id) volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix) # Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id) volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix # For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME": if volume_prefix == "USER VOLUME":
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'" sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
else: else:
sql = f"REMOVE {volume_prefix} FILE '{filename}'" sql = f"REMOVE {volume_prefix} FILE '{filename}'"
self._execute_sql(sql) self._execute_sql(sql)
logger.debug(f"File {filename} deleted from ClickZetta Volume") logger.debug(f"File {filename} deleted from ClickZetta Volume")
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""Scan files and directories in ClickZetta Volume. """Scan files and directories in ClickZetta Volume.
Args: Args:
path: Path to scan (dataset_id for table volumes) path: Path to scan (dataset_id for table volumes)
files: Include files in results files: Include files in results
directories: Include directories in results directories: Include directories in results
Returns: Returns:
List of file/directory paths List of file/directory paths
""" """
@ -489,9 +487,9 @@ class ClickZettaVolumeStorage(BaseStorage):
if self._config.volume_type == "table": if self._config.volume_type == "table":
dataset_id = path dataset_id = path
path = "" # Root of the table volume path = "" # Root of the table volume
volume_prefix = self._get_volume_sql_prefix(dataset_id) volume_prefix = self._get_volume_sql_prefix(dataset_id)
# For User Volume, add dify prefix to path # For User Volume, add dify prefix to path
if volume_prefix == "USER VOLUME": if volume_prefix == "USER VOLUME":
if path: if path:
@ -504,26 +502,24 @@ class ClickZettaVolumeStorage(BaseStorage):
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'" sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
else: else:
sql = f"LIST {volume_prefix}" sql = f"LIST {volume_prefix}"
rows = self._execute_sql(sql, fetch=True) rows = self._execute_sql(sql, fetch=True)
result = [] result = []
for row in rows: for row in rows:
file_path = row[0] # relative_path column file_path = row[0] # relative_path column
# For User Volume, remove dify prefix from results # For User Volume, remove dify prefix from results
dify_prefix_with_slash = f"{self._config.dify_prefix}/" dify_prefix_with_slash = f"{self._config.dify_prefix}/"
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash): if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
file_path = file_path[len(dify_prefix_with_slash):] # Remove prefix file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
if files and not file_path.endswith("/"): if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
result.append(file_path)
elif directories and file_path.endswith("/"):
result.append(file_path) result.append(file_path)
logger.debug(f"Scanned {len(result)} items in path {path}") logger.debug(f"Scanned {len(result)} items in path {path}")
return result return result
except Exception as e: except Exception as e:
logger.error(f"Error scanning path {path}: {e}") logger.error(f"Error scanning path {path}: {e}")
return [] return []

@ -6,26 +6,27 @@
import json import json
import logging import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict
from enum import Enum from enum import Enum
from typing import Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FileStatus(Enum): class FileStatus(Enum):
"""文件状态枚举""" """文件状态枚举"""
ACTIVE = "active" # 活跃状态
ACTIVE = "active" # 活跃状态
ARCHIVED = "archived" # 已归档 ARCHIVED = "archived" # 已归档
DELETED = "deleted" # 已删除(软删除) DELETED = "deleted" # 已删除(软删除)
BACKUP = "backup" # 备份文件 BACKUP = "backup" # 备份文件
@dataclass @dataclass
class FileMetadata: class FileMetadata:
"""文件元数据""" """文件元数据"""
filename: str filename: str
size: int size: int
created_at: datetime created_at: datetime
@ -33,33 +34,33 @@ class FileMetadata:
version: int version: int
status: FileStatus status: FileStatus
checksum: Optional[str] = None checksum: Optional[str] = None
tags: Optional[Dict[str, str]] = None tags: Optional[dict[str, str]] = None
parent_version: Optional[int] = None parent_version: Optional[int] = None
def to_dict(self) -> Dict: def to_dict(self) -> dict:
"""转换为字典格式""" """转换为字典格式"""
data = asdict(self) data = asdict(self)
data['created_at'] = self.created_at.isoformat() data["created_at"] = self.created_at.isoformat()
data['modified_at'] = self.modified_at.isoformat() data["modified_at"] = self.modified_at.isoformat()
data['status'] = self.status.value data["status"] = self.status.value
return data return data
@classmethod @classmethod
def from_dict(cls, data: Dict) -> 'FileMetadata': def from_dict(cls, data: dict) -> "FileMetadata":
"""从字典创建实例""" """从字典创建实例"""
data = data.copy() data = data.copy()
data['created_at'] = datetime.fromisoformat(data['created_at']) data["created_at"] = datetime.fromisoformat(data["created_at"])
data['modified_at'] = datetime.fromisoformat(data['modified_at']) data["modified_at"] = datetime.fromisoformat(data["modified_at"])
data['status'] = FileStatus(data['status']) data["status"] = FileStatus(data["status"])
return cls(**data) return cls(**data)
class FileLifecycleManager: class FileLifecycleManager:
"""文件生命周期管理器""" """文件生命周期管理器"""
def __init__(self, storage, dataset_id: Optional[str] = None): def __init__(self, storage, dataset_id: Optional[str] = None):
"""初始化生命周期管理器 """初始化生命周期管理器
Args: Args:
storage: ClickZetta Volume存储实例 storage: ClickZetta Volume存储实例
dataset_id: 数据集ID用于Table Volume dataset_id: 数据集ID用于Table Volume
@ -70,61 +71,61 @@ class FileLifecycleManager:
self._version_prefix = ".versions/" self._version_prefix = ".versions/"
self._backup_prefix = ".backups/" self._backup_prefix = ".backups/"
self._deleted_prefix = ".deleted/" self._deleted_prefix = ".deleted/"
# 获取权限管理器(如果存在) # 获取权限管理器(如果存在)
self._permission_manager = getattr(storage, '_permission_manager', None) self._permission_manager = getattr(storage, "_permission_manager", None)
def save_with_lifecycle(self, filename: str, data: bytes, def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
tags: Optional[Dict[str, str]] = None) -> FileMetadata:
"""保存文件并管理生命周期 """保存文件并管理生命周期
Args: Args:
filename: 文件名 filename: 文件名
data: 文件内容 data: 文件内容
tags: 文件标签 tags: 文件标签
Returns: Returns:
文件元数据 文件元数据
""" """
# 权限检查 # 权限检查
if not self._check_permission(filename, "save"): if not self._check_permission(filename, "save"):
from .volume_permissions import VolumePermissionError from .volume_permissions import VolumePermissionError
raise VolumePermissionError( raise VolumePermissionError(
f"Permission denied for lifecycle save operation on file: {filename}", f"Permission denied for lifecycle save operation on file: {filename}",
operation="save", operation="save",
volume_type=getattr(self._storage, '_config', {}).get('volume_type', 'unknown'), volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
dataset_id=self._dataset_id dataset_id=self._dataset_id,
) )
try: try:
# 1. 检查是否存在旧版本 # 1. 检查是否存在旧版本
metadata_dict = self._load_metadata() metadata_dict = self._load_metadata()
current_metadata = metadata_dict.get(filename) current_metadata = metadata_dict.get(filename)
# 2. 如果存在旧版本,创建版本备份 # 2. 如果存在旧版本,创建版本备份
if current_metadata: if current_metadata:
self._create_version_backup(filename, current_metadata) self._create_version_backup(filename, current_metadata)
# 3. 计算文件信息 # 3. 计算文件信息
now = datetime.now() now = datetime.now()
checksum = self._calculate_checksum(data) checksum = self._calculate_checksum(data)
new_version = (current_metadata['version'] + 1) if current_metadata else 1 new_version = (current_metadata["version"] + 1) if current_metadata else 1
# 4. 保存新文件 # 4. 保存新文件
self._storage.save(filename, data) self._storage.save(filename, data)
# 5. 创建元数据 # 5. 创建元数据
created_at = now created_at = now
parent_version = None parent_version = None
if current_metadata: if current_metadata:
# 如果created_at是字符串转换为datetime # 如果created_at是字符串转换为datetime
if isinstance(current_metadata['created_at'], str): if isinstance(current_metadata["created_at"], str):
created_at = datetime.fromisoformat(current_metadata['created_at']) created_at = datetime.fromisoformat(current_metadata["created_at"])
else: else:
created_at = current_metadata['created_at'] created_at = current_metadata["created_at"]
parent_version = current_metadata['version'] parent_version = current_metadata["version"]
file_metadata = FileMetadata( file_metadata = FileMetadata(
filename=filename, filename=filename,
size=len(data), size=len(data),
@ -134,26 +135,26 @@ class FileLifecycleManager:
status=FileStatus.ACTIVE, status=FileStatus.ACTIVE,
checksum=checksum, checksum=checksum,
tags=tags or {}, tags=tags or {},
parent_version=parent_version parent_version=parent_version,
) )
# 6. 更新元数据 # 6. 更新元数据
metadata_dict[filename] = file_metadata.to_dict() metadata_dict[filename] = file_metadata.to_dict()
self._save_metadata(metadata_dict) self._save_metadata(metadata_dict)
logger.info(f"File {filename} saved with lifecycle management, version {new_version}") logger.info(f"File {filename} saved with lifecycle management, version {new_version}")
return file_metadata return file_metadata
except Exception as e: except Exception as e:
logger.error(f"Failed to save file with lifecycle: {e}") logger.error(f"Failed to save file with lifecycle: {e}")
raise raise
def get_file_metadata(self, filename: str) -> Optional[FileMetadata]: def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
"""获取文件元数据 """获取文件元数据
Args: Args:
filename: 文件名 filename: 文件名
Returns: Returns:
文件元数据如果不存在返回None 文件元数据如果不存在返回None
""" """
@ -165,24 +166,24 @@ class FileLifecycleManager:
except Exception as e: except Exception as e:
logger.error(f"Failed to get file metadata for {filename}: {e}") logger.error(f"Failed to get file metadata for {filename}: {e}")
return None return None
def list_file_versions(self, filename: str) -> List[FileMetadata]: def list_file_versions(self, filename: str) -> list[FileMetadata]:
"""列出文件的所有版本 """列出文件的所有版本
Args: Args:
filename: 文件名 filename: 文件名
Returns: Returns:
文件版本列表按版本号排序 文件版本列表按版本号排序
""" """
try: try:
versions = [] versions = []
# 获取当前版本 # 获取当前版本
current_metadata = self.get_file_metadata(filename) current_metadata = self.get_file_metadata(filename)
if current_metadata: if current_metadata:
versions.append(current_metadata) versions.append(current_metadata)
# 获取历史版本 # 获取历史版本
version_pattern = f"{self._version_prefix}{filename}.v*" version_pattern = f"{self._version_prefix}{filename}.v*"
try: try:
@ -200,52 +201,52 @@ class FileLifecycleManager:
except: except:
# 如果无法扫描版本文件,只返回当前版本 # 如果无法扫描版本文件,只返回当前版本
pass pass
return sorted(versions, key=lambda x: x.version, reverse=True) return sorted(versions, key=lambda x: x.version, reverse=True)
except Exception as e: except Exception as e:
logger.error(f"Failed to list file versions for {filename}: {e}") logger.error(f"Failed to list file versions for {filename}: {e}")
return [] return []
def restore_version(self, filename: str, version: int) -> bool: def restore_version(self, filename: str, version: int) -> bool:
"""恢复文件到指定版本 """恢复文件到指定版本
Args: Args:
filename: 文件名 filename: 文件名
version: 要恢复的版本号 version: 要恢复的版本号
Returns: Returns:
恢复是否成功 恢复是否成功
""" """
try: try:
version_filename = f"{self._version_prefix}{filename}.v{version}" version_filename = f"{self._version_prefix}{filename}.v{version}"
# 检查版本文件是否存在 # 检查版本文件是否存在
if not self._storage.exists(version_filename): if not self._storage.exists(version_filename):
logger.warning(f"Version {version} of {filename} not found") logger.warning(f"Version {version} of {filename} not found")
return False return False
# 读取版本文件内容 # 读取版本文件内容
version_data = self._storage.load_once(version_filename) version_data = self._storage.load_once(version_filename)
# 保存当前版本为备份 # 保存当前版本为备份
current_metadata = self.get_file_metadata(filename) current_metadata = self.get_file_metadata(filename)
if current_metadata: if current_metadata:
self._create_version_backup(filename, current_metadata.to_dict()) self._create_version_backup(filename, current_metadata.to_dict())
# 恢复文件 # 恢复文件
return self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)}) return self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
except Exception as e: except Exception as e:
logger.error(f"Failed to restore {filename} to version {version}: {e}") logger.error(f"Failed to restore {filename} to version {version}: {e}")
return False return False
def archive_file(self, filename: str) -> bool: def archive_file(self, filename: str) -> bool:
"""归档文件 """归档文件
Args: Args:
filename: 文件名 filename: 文件名
Returns: Returns:
归档是否成功 归档是否成功
""" """
@ -253,32 +254,32 @@ class FileLifecycleManager:
if not self._check_permission(filename, "archive"): if not self._check_permission(filename, "archive"):
logger.warning(f"Permission denied for archive operation on file: {filename}") logger.warning(f"Permission denied for archive operation on file: {filename}")
return False return False
try: try:
# 更新文件状态为归档 # 更新文件状态为归档
metadata_dict = self._load_metadata() metadata_dict = self._load_metadata()
if filename not in metadata_dict: if filename not in metadata_dict:
logger.warning(f"File {filename} not found in metadata") logger.warning(f"File {filename} not found in metadata")
return False return False
metadata_dict[filename]['status'] = FileStatus.ARCHIVED.value metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
metadata_dict[filename]['modified_at'] = datetime.now().isoformat() metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict) self._save_metadata(metadata_dict)
logger.info(f"File {filename} archived successfully") logger.info(f"File {filename} archived successfully")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to archive file {filename}: {e}") logger.error(f"Failed to archive file {filename}: {e}")
return False return False
def soft_delete_file(self, filename: str) -> bool: def soft_delete_file(self, filename: str) -> bool:
"""软删除文件(移动到删除目录) """软删除文件(移动到删除目录)
Args: Args:
filename: 文件名 filename: 文件名
Returns: Returns:
删除是否成功 删除是否成功
""" """
@ -286,61 +287,61 @@ class FileLifecycleManager:
if not self._check_permission(filename, "delete"): if not self._check_permission(filename, "delete"):
logger.warning(f"Permission denied for soft delete operation on file: {filename}") logger.warning(f"Permission denied for soft delete operation on file: {filename}")
return False return False
try: try:
# 检查文件是否存在 # 检查文件是否存在
if not self._storage.exists(filename): if not self._storage.exists(filename):
logger.warning(f"File {filename} not found") logger.warning(f"File {filename} not found")
return False return False
# 读取文件内容 # 读取文件内容
file_data = self._storage.load_once(filename) file_data = self._storage.load_once(filename)
# 移动到删除目录 # 移动到删除目录
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}" deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self._storage.save(deleted_filename, file_data) self._storage.save(deleted_filename, file_data)
# 删除原文件 # 删除原文件
self._storage.delete(filename) self._storage.delete(filename)
# 更新元数据 # 更新元数据
metadata_dict = self._load_metadata() metadata_dict = self._load_metadata()
if filename in metadata_dict: if filename in metadata_dict:
metadata_dict[filename]['status'] = FileStatus.DELETED.value metadata_dict[filename]["status"] = FileStatus.DELETED.value
metadata_dict[filename]['modified_at'] = datetime.now().isoformat() metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict) self._save_metadata(metadata_dict)
logger.info(f"File {filename} soft deleted successfully") logger.info(f"File {filename} soft deleted successfully")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to soft delete file {filename}: {e}") logger.error(f"Failed to soft delete file {filename}: {e}")
return False return False
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int: def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
"""清理旧版本文件 """清理旧版本文件
Args: Args:
max_versions: 保留的最大版本数 max_versions: 保留的最大版本数
max_age_days: 版本文件的最大保留天数 max_age_days: 版本文件的最大保留天数
Returns: Returns:
清理的文件数量 清理的文件数量
""" """
try: try:
cleaned_count = 0 cleaned_count = 0
cutoff_date = datetime.now() - timedelta(days=max_age_days) cutoff_date = datetime.now() - timedelta(days=max_age_days)
# 获取所有版本文件 # 获取所有版本文件
try: try:
all_files = self._storage.scan(self._dataset_id or "", files=True) all_files = self._storage.scan(self._dataset_id or "", files=True)
version_files = [f for f in all_files if f.startswith(self._version_prefix)] version_files = [f for f in all_files if f.startswith(self._version_prefix)]
# 按文件分组 # 按文件分组
file_versions = {} file_versions = {}
for version_file in version_files: for version_file in version_files:
# 解析文件名和版本 # 解析文件名和版本
parts = version_file[len(self._version_prefix):].split(".v") parts = version_file[len(self._version_prefix) :].split(".v")
if len(parts) >= 2: if len(parts) >= 2:
base_filename = parts[0] base_filename = parts[0]
version_part = parts[1].split(".")[0] version_part = parts[1].split(".")[0]
@ -351,12 +352,12 @@ class FileLifecycleManager:
file_versions[base_filename].append((version_num, version_file)) file_versions[base_filename].append((version_num, version_file))
except ValueError: except ValueError:
continue continue
# 清理每个文件的旧版本 # 清理每个文件的旧版本
for base_filename, versions in file_versions.items(): for base_filename, versions in file_versions.items():
# 按版本号排序 # 按版本号排序
versions.sort(key=lambda x: x[0], reverse=True) versions.sort(key=lambda x: x[0], reverse=True)
# 保留最新的max_versions个版本删除其余的 # 保留最新的max_versions个版本删除其余的
if len(versions) > max_versions: if len(versions) > max_versions:
to_delete = versions[max_versions:] to_delete = versions[max_versions:]
@ -364,27 +365,27 @@ class FileLifecycleManager:
self._storage.delete(version_file) self._storage.delete(version_file)
cleaned_count += 1 cleaned_count += 1
logger.debug(f"Cleaned old version: {version_file}") logger.debug(f"Cleaned old version: {version_file}")
logger.info(f"Cleaned {cleaned_count} old version files") logger.info(f"Cleaned {cleaned_count} old version files")
except Exception as e: except Exception as e:
logger.warning(f"Could not scan for version files: {e}") logger.warning(f"Could not scan for version files: {e}")
return cleaned_count return cleaned_count
except Exception as e: except Exception as e:
logger.error(f"Failed to cleanup old versions: {e}") logger.error(f"Failed to cleanup old versions: {e}")
return 0 return 0
def get_storage_statistics(self) -> Dict[str, any]: def get_storage_statistics(self) -> dict[str, any]:
"""获取存储统计信息 """获取存储统计信息
Returns: Returns:
存储统计字典 存储统计字典
""" """
try: try:
metadata_dict = self._load_metadata() metadata_dict = self._load_metadata()
stats = { stats = {
"total_files": len(metadata_dict), "total_files": len(metadata_dict),
"active_files": 0, "active_files": 0,
@ -393,15 +394,15 @@ class FileLifecycleManager:
"total_size": 0, "total_size": 0,
"versions_count": 0, "versions_count": 0,
"oldest_file": None, "oldest_file": None,
"newest_file": None "newest_file": None,
} }
oldest_date = None oldest_date = None
newest_date = None newest_date = None
for filename, metadata in metadata_dict.items(): for filename, metadata in metadata_dict.items():
file_meta = FileMetadata.from_dict(metadata) file_meta = FileMetadata.from_dict(metadata)
# 统计文件状态 # 统计文件状态
if file_meta.status == FileStatus.ACTIVE: if file_meta.status == FileStatus.ACTIVE:
stats["active_files"] += 1 stats["active_files"] += 1
@ -409,84 +410,85 @@ class FileLifecycleManager:
stats["archived_files"] += 1 stats["archived_files"] += 1
elif file_meta.status == FileStatus.DELETED: elif file_meta.status == FileStatus.DELETED:
stats["deleted_files"] += 1 stats["deleted_files"] += 1
# 统计大小 # 统计大小
stats["total_size"] += file_meta.size stats["total_size"] += file_meta.size
# 统计版本 # 统计版本
stats["versions_count"] += file_meta.version stats["versions_count"] += file_meta.version
# 找出最新和最旧的文件 # 找出最新和最旧的文件
if oldest_date is None or file_meta.created_at < oldest_date: if oldest_date is None or file_meta.created_at < oldest_date:
oldest_date = file_meta.created_at oldest_date = file_meta.created_at
stats["oldest_file"] = filename stats["oldest_file"] = filename
if newest_date is None or file_meta.modified_at > newest_date: if newest_date is None or file_meta.modified_at > newest_date:
newest_date = file_meta.modified_at newest_date = file_meta.modified_at
stats["newest_file"] = filename stats["newest_file"] = filename
return stats return stats
except Exception as e: except Exception as e:
logger.error(f"Failed to get storage statistics: {e}") logger.error(f"Failed to get storage statistics: {e}")
return {} return {}
def _create_version_backup(self, filename: str, metadata: Dict): def _create_version_backup(self, filename: str, metadata: dict):
"""创建版本备份""" """创建版本备份"""
try: try:
# 读取当前文件内容 # 读取当前文件内容
current_data = self._storage.load_once(filename) current_data = self._storage.load_once(filename)
# 保存为版本文件 # 保存为版本文件
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}" version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
self._storage.save(version_filename, current_data) self._storage.save(version_filename, current_data)
logger.debug(f"Created version backup: {version_filename}") logger.debug(f"Created version backup: {version_filename}")
except Exception as e: except Exception as e:
logger.warning(f"Failed to create version backup for {filename}: {e}") logger.warning(f"Failed to create version backup for {filename}: {e}")
def _load_metadata(self) -> Dict: def _load_metadata(self) -> dict:
"""加载元数据文件""" """加载元数据文件"""
try: try:
if self._storage.exists(self._metadata_file): if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file) metadata_content = self._storage.load_once(self._metadata_file)
return json.loads(metadata_content.decode('utf-8')) return json.loads(metadata_content.decode("utf-8"))
else: else:
return {} return {}
except Exception as e: except Exception as e:
logger.warning(f"Failed to load metadata: {e}") logger.warning(f"Failed to load metadata: {e}")
return {} return {}
def _save_metadata(self, metadata_dict: Dict): def _save_metadata(self, metadata_dict: dict):
"""保存元数据文件""" """保存元数据文件"""
try: try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
self._storage.save(self._metadata_file, metadata_content.encode('utf-8')) self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
logger.debug("Metadata saved successfully") logger.debug("Metadata saved successfully")
except Exception as e: except Exception as e:
logger.error(f"Failed to save metadata: {e}") logger.error(f"Failed to save metadata: {e}")
raise raise
def _calculate_checksum(self, data: bytes) -> str: def _calculate_checksum(self, data: bytes) -> str:
"""计算文件校验和""" """计算文件校验和"""
import hashlib import hashlib
return hashlib.md5(data).hexdigest() return hashlib.md5(data).hexdigest()
def _check_permission(self, filename: str, operation: str) -> bool: def _check_permission(self, filename: str, operation: str) -> bool:
"""检查文件操作权限 """检查文件操作权限
Args: Args:
filename: 文件名 filename: 文件名
operation: 操作类型 operation: 操作类型
Returns: Returns:
True if permission granted, False otherwise True if permission granted, False otherwise
""" """
# 如果没有权限管理器,默认允许 # 如果没有权限管理器,默认允许
if not self._permission_manager: if not self._permission_manager:
return True return True
try: try:
# 根据操作类型映射到权限 # 根据操作类型映射到权限
operation_mapping = { operation_mapping = {
@ -494,17 +496,17 @@ class FileLifecycleManager:
"load": "load_once", "load": "load_once",
"delete": "delete", "delete": "delete",
"archive": "delete", # 归档需要删除权限 "archive": "delete", # 归档需要删除权限
"restore": "save", # 恢复需要写权限 "restore": "save", # 恢复需要写权限
"cleanup": "delete", # 清理需要删除权限 "cleanup": "delete", # 清理需要删除权限
"read": "load_once", "read": "load_once",
"write": "save" "write": "save",
} }
mapped_operation = operation_mapping.get(operation, operation) mapped_operation = operation_mapping.get(operation, operation)
# 检查权限 # 检查权限
return self._permission_manager.validate_operation(mapped_operation, self._dataset_id) return self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
except Exception as e: except Exception as e:
logger.error(f"Permission check failed for {filename} operation {operation}: {e}") logger.error(f"Permission check failed for {filename} operation {operation}: {e}")
# 安全默认:权限检查失败时拒绝访问 # 安全默认:权限检查失败时拒绝访问

@ -6,13 +6,14 @@
import logging import logging
from enum import Enum from enum import Enum
from typing import Dict, Optional, Set from typing import Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class VolumePermission(Enum): class VolumePermission(Enum):
"""Volume权限类型枚举""" """Volume权限类型枚举"""
READ = "SELECT" # 对应ClickZetta的SELECT权限 READ = "SELECT" # 对应ClickZetta的SELECT权限
WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限 WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限
LIST = "SELECT" # 列出文件需要SELECT权限 LIST = "SELECT" # 列出文件需要SELECT权限
@ -35,18 +36,19 @@ class VolumePermissionManager:
if isinstance(connection_or_config, dict): if isinstance(connection_or_config, dict):
# 从配置字典创建连接 # 从配置字典创建连接
import clickzetta import clickzetta
config = connection_or_config config = connection_or_config
self._connection = clickzetta.connect( self._connection = clickzetta.connect(
username=config.get('username'), username=config.get("username"),
password=config.get('password'), password=config.get("password"),
instance=config.get('instance'), instance=config.get("instance"),
service=config.get('service'), service=config.get("service"),
workspace=config.get('workspace'), workspace=config.get("workspace"),
vcluster=config.get('vcluster'), vcluster=config.get("vcluster"),
schema=config.get('schema') or config.get('database') schema=config.get("schema") or config.get("database"),
) )
self._volume_type = config.get('volume_type', volume_type) self._volume_type = config.get("volume_type", volume_type)
self._volume_name = config.get('volume_name', volume_name) self._volume_name = config.get("volume_name", volume_name)
else: else:
# 直接使用连接对象 # 直接使用连接对象
self._connection = connection_or_config self._connection = connection_or_config
@ -58,7 +60,7 @@ class VolumePermissionManager:
if not self._volume_type: if not self._volume_type:
raise ValueError("volume_type is required") raise ValueError("volume_type is required")
self._permission_cache: Dict[str, Set[str]] = {} self._permission_cache: dict[str, set[str]] = {}
self._current_username = None # 将从连接中获取当前用户名 self._current_username = None # 将从连接中获取当前用户名
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
@ -119,7 +121,7 @@ class VolumePermissionManager:
except Exception as e: except Exception as e:
logger.error(f"User Volume permission check failed: {e}") logger.error(f"User Volume permission check failed: {e}")
# 对于User Volume如果权限检查失败可能是配置问题给出更友好的错误提示 # 对于User Volume如果权限检查失败可能是配置问题给出更友好的错误提示
logger.info(f"User Volume permission check failed, but permission checking is disabled in this version") logger.info("User Volume permission check failed, but permission checking is disabled in this version")
return False return False
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
@ -144,8 +146,10 @@ class VolumePermissionManager:
# 检查是否有所需的所有权限 # 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions) has_permission = required_permissions.issubset(permissions)
logger.debug(f"Table Volume permission check for {table_name}, operation {operation.name}: " logger.debug(
f"required={required_permissions}, has={permissions}, granted={has_permission}") f"Table Volume permission check for {table_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}"
)
return has_permission return has_permission
@ -180,8 +184,10 @@ class VolumePermissionManager:
# 检查是否有所需的所有权限 # 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions) has_permission = required_permissions.issubset(permissions)
logger.debug(f"External Volume permission check for {self._volume_name}, operation {operation.name}: " logger.debug(
f"required={required_permissions}, has={permissions}, granted={has_permission}") f"External Volume permission check for {self._volume_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}"
)
# 如果权限检查失败,尝试备选验证 # 如果权限检查失败,尝试备选验证
if not has_permission: if not has_permission:
@ -203,10 +209,10 @@ class VolumePermissionManager:
except Exception as e: except Exception as e:
logger.error(f"External volume permission check failed for {self._volume_name}: {e}") logger.error(f"External volume permission check failed for {self._volume_name}: {e}")
logger.info(f"External Volume permission check failed, but permission checking is disabled in this version") logger.info("External Volume permission check failed, but permission checking is disabled in this version")
return False return False
def _get_table_permissions(self, table_name: str) -> Set[str]: def _get_table_permissions(self, table_name: str) -> set[str]:
"""获取用户对指定表的权限 """获取用户对指定表的权限
Args: Args:
@ -236,14 +242,12 @@ class VolumePermissionManager:
object_name = grant[2] if len(grant) > 2 else "" object_name = grant[2] if len(grant) > 2 else ""
# 检查是否是对该表的权限 # 检查是否是对该表的权限
if object_type == "TABLE" and object_name == table_name: if (
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: object_type == "TABLE"
if privilege == "ALL": and object_name == table_name
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) or object_type == "SCHEMA"
else: and object_name in table_name
permissions.add(privilege) ):
# 检查是否是对整个schema的权限
elif object_type == "SCHEMA" and object_name in table_name:
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL": if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
@ -284,7 +288,7 @@ class VolumePermissionManager:
return "unknown" return "unknown"
def _get_user_permissions(self, username: str) -> Set[str]: def _get_user_permissions(self, username: str) -> set[str]:
"""获取用户的基本权限集合""" """获取用户的基本权限集合"""
cache_key = f"user_permissions:{username}" cache_key = f"user_permissions:{username}"
@ -321,7 +325,7 @@ class VolumePermissionManager:
self._permission_cache[cache_key] = permissions self._permission_cache[cache_key] = permissions
return permissions return permissions
def _get_external_volume_permissions(self, volume_name: str) -> Set[str]: def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
"""获取用户对指定External Volume的权限 """获取用户对指定External Volume的权限
Args: Args:
@ -363,10 +367,9 @@ class VolumePermissionManager:
) )
# 检查是否是对该Volume的权限或者是层级权限 # 检查是否是对该Volume的权限或者是层级权限
if ((granted_type == "PRIVILEGE" and granted_on == "VOLUME" and if (
object_name.endswith(volume_name)) or granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
(granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME")): ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
logger.info(f"Matching grant found for {volume_name}") logger.info(f"Matching grant found for {volume_name}")
if "READ" in privilege: if "READ" in privilege:
@ -424,7 +427,7 @@ class VolumePermissionManager:
self._permission_cache.clear() self._permission_cache.clear()
logger.debug("Permission cache cleared") logger.debug("Permission cache cleared")
def get_permission_summary(self, dataset_id: Optional[str] = None) -> Dict[str, bool]: def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
"""获取权限摘要 """获取权限摘要
Args: Args:
@ -514,10 +517,16 @@ class VolumePermissionManager:
"""检查路径是否包含路径遍历攻击""" """检查路径是否包含路径遍历攻击"""
# 检查常见的路径遍历模式 # 检查常见的路径遍历模式
traversal_patterns = [ traversal_patterns = [
"../", "..\\", "../",
"..%2f", "..%2F", "..%5c", "..%5C", "..\\",
"%2e%2e%2f", "%2e%2e%5c", "..%2f",
"....//", "....\\\\", "..%2F",
"..%5c",
"..%5C",
"%2e%2e%2f",
"%2e%2e%5c",
"....//",
"....\\\\",
] ]
file_path_lower = file_path.lower() file_path_lower = file_path.lower()
@ -539,9 +548,21 @@ class VolumePermissionManager:
def _is_sensitive_path(self, file_path: str) -> bool: def _is_sensitive_path(self, file_path: str) -> bool:
"""检查路径是否为敏感路径""" """检查路径是否为敏感路径"""
sensitive_patterns = [ sensitive_patterns = [
"passwd", "shadow", "hosts", "config", "secrets", "passwd",
"private", "key", "certificate", "cert", "ssl", "shadow",
"database", "backup", "dump", "log", "tmp" "hosts",
"config",
"secrets",
"private",
"key",
"certificate",
"cert",
"ssl",
"database",
"backup",
"dump",
"log",
"tmp",
] ]
file_path_lower = file_path.lower() file_path_lower = file_path.lower()
@ -591,9 +612,9 @@ class VolumePermissionError(Exception):
super().__init__(message) super().__init__(message)
def check_volume_permission(permission_manager: VolumePermissionManager, def check_volume_permission(
operation: str, permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
dataset_id: Optional[str] = None) -> None: ) -> None:
"""权限检查装饰器函数 """权限检查装饰器函数
Args: Args:
@ -610,8 +631,5 @@ def check_volume_permission(permission_manager: VolumePermissionManager,
error_message += f" (dataset: {dataset_id})" error_message += f" (dataset: {dataset_id})"
raise VolumePermissionError( raise VolumePermissionError(
error_message, error_message, operation=operation, volume_type=permission_manager._volume_type, dataset_id=dataset_id
operation=operation,
volume_type=permission_manager._volume_type,
dataset_id=dataset_id
) )

@ -3,7 +3,6 @@
import os import os
import tempfile import tempfile
import unittest import unittest
from unittest.mock import patch
import pytest import pytest
@ -15,7 +14,7 @@ from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
class TestClickZettaVolumeStorage(unittest.TestCase): class TestClickZettaVolumeStorage(unittest.TestCase):
"""Test cases for ClickZetta Volume Storage.""" """Test cases for ClickZetta Volume Storage."""
def setUp(self): def setUp(self):
"""Set up test environment.""" """Set up test environment."""
self.config = ClickZettaVolumeConfig( self.config = ClickZettaVolumeConfig(
@ -27,89 +26,83 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"), schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
volume_type="table", volume_type="table",
table_prefix="test_dataset_" table_prefix="test_dataset_",
) )
@pytest.mark.skipif( @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
not os.getenv("CLICKZETTA_USERNAME"),
reason="ClickZetta credentials not provided"
)
def test_user_volume_operations(self): def test_user_volume_operations(self):
"""Test basic operations with User Volume.""" """Test basic operations with User Volume."""
config = self.config config = self.config
config.volume_type = "user" config.volume_type = "user"
storage = ClickZettaVolumeStorage(config) storage = ClickZettaVolumeStorage(config)
# Test file operations # Test file operations
test_filename = "test_file.txt" test_filename = "test_file.txt"
test_content = b"Hello, ClickZetta Volume!" test_content = b"Hello, ClickZetta Volume!"
# Save file # Save file
storage.save(test_filename, test_content) storage.save(test_filename, test_content)
# Check if file exists # Check if file exists
self.assertTrue(storage.exists(test_filename)) self.assertTrue(storage.exists(test_filename))
# Load file # Load file
loaded_content = storage.load_once(test_filename) loaded_content = storage.load_once(test_filename)
self.assertEqual(loaded_content, test_content) self.assertEqual(loaded_content, test_content)
# Test streaming # Test streaming
stream_content = b"" stream_content = b""
for chunk in storage.load_stream(test_filename): for chunk in storage.load_stream(test_filename):
stream_content += chunk stream_content += chunk
self.assertEqual(stream_content, test_content) self.assertEqual(stream_content, test_content)
# Test download # Test download
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
storage.download(test_filename, temp_file.name) storage.download(test_filename, temp_file.name)
with open(temp_file.name, "rb") as f: with open(temp_file.name, "rb") as f:
downloaded_content = f.read() downloaded_content = f.read()
self.assertEqual(downloaded_content, test_content) self.assertEqual(downloaded_content, test_content)
# Test scan # Test scan
files = storage.scan("", files=True, directories=False) files = storage.scan("", files=True, directories=False)
self.assertIn(test_filename, files) self.assertIn(test_filename, files)
# Delete file # Delete file
storage.delete(test_filename) storage.delete(test_filename)
self.assertFalse(storage.exists(test_filename)) self.assertFalse(storage.exists(test_filename))
@pytest.mark.skipif( @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
not os.getenv("CLICKZETTA_USERNAME"),
reason="ClickZetta credentials not provided"
)
def test_table_volume_operations(self): def test_table_volume_operations(self):
"""Test basic operations with Table Volume.""" """Test basic operations with Table Volume."""
config = self.config config = self.config
config.volume_type = "table" config.volume_type = "table"
storage = ClickZettaVolumeStorage(config) storage = ClickZettaVolumeStorage(config)
# Test file operations with dataset_id # Test file operations with dataset_id
dataset_id = "12345" dataset_id = "12345"
test_filename = f"{dataset_id}/test_file.txt" test_filename = f"{dataset_id}/test_file.txt"
test_content = b"Hello, Table Volume!" test_content = b"Hello, Table Volume!"
# Save file # Save file
storage.save(test_filename, test_content) storage.save(test_filename, test_content)
# Check if file exists # Check if file exists
self.assertTrue(storage.exists(test_filename)) self.assertTrue(storage.exists(test_filename))
# Load file # Load file
loaded_content = storage.load_once(test_filename) loaded_content = storage.load_once(test_filename)
self.assertEqual(loaded_content, test_content) self.assertEqual(loaded_content, test_content)
# Test scan for dataset # Test scan for dataset
files = storage.scan(dataset_id, files=True, directories=False) files = storage.scan(dataset_id, files=True, directories=False)
self.assertIn("test_file.txt", files) self.assertIn("test_file.txt", files)
# Delete file # Delete file
storage.delete(test_filename) storage.delete(test_filename)
self.assertFalse(storage.exists(test_filename)) self.assertFalse(storage.exists(test_filename))
def test_config_validation(self): def test_config_validation(self):
"""Test configuration validation.""" """Test configuration validation."""
# Test missing required fields # Test missing required fields
@ -119,56 +112,51 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
password="pass", password="pass",
instance="instance", instance="instance",
) )
# Test invalid volume type # Test invalid volume type
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ClickZettaVolumeConfig( ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")
username="user",
password="pass",
instance="instance",
volume_type="invalid_type"
)
# Test external volume without volume_name # Test external volume without volume_name
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ClickZettaVolumeConfig( ClickZettaVolumeConfig(
username="user", username="user",
password="pass", password="pass",
instance="instance", instance="instance",
volume_type="external" volume_type="external",
# Missing volume_name # Missing volume_name
) )
def test_volume_path_generation(self): def test_volume_path_generation(self):
"""Test volume path generation for different types.""" """Test volume path generation for different types."""
storage = ClickZettaVolumeStorage(self.config) storage = ClickZettaVolumeStorage(self.config)
# Test table volume path # Test table volume path
path = storage._get_volume_path("test.txt", "12345") path = storage._get_volume_path("test.txt", "12345")
self.assertEqual(path, "test_dataset_12345/test.txt") self.assertEqual(path, "test_dataset_12345/test.txt")
# Test path with existing dataset_id prefix # Test path with existing dataset_id prefix
path = storage._get_volume_path("12345/test.txt") path = storage._get_volume_path("12345/test.txt")
self.assertEqual(path, "12345/test.txt") self.assertEqual(path, "12345/test.txt")
# Test user volume # Test user volume
storage._config.volume_type = "user" storage._config.volume_type = "user"
path = storage._get_volume_path("test.txt") path = storage._get_volume_path("test.txt")
self.assertEqual(path, "test.txt") self.assertEqual(path, "test.txt")
def test_sql_prefix_generation(self): def test_sql_prefix_generation(self):
"""Test SQL prefix generation for different volume types.""" """Test SQL prefix generation for different volume types."""
storage = ClickZettaVolumeStorage(self.config) storage = ClickZettaVolumeStorage(self.config)
# Test table volume SQL prefix # Test table volume SQL prefix
prefix = storage._get_volume_sql_prefix("12345") prefix = storage._get_volume_sql_prefix("12345")
self.assertEqual(prefix, "TABLE VOLUME test_dataset_12345") self.assertEqual(prefix, "TABLE VOLUME test_dataset_12345")
# Test user volume SQL prefix # Test user volume SQL prefix
storage._config.volume_type = "user" storage._config.volume_type = "user"
prefix = storage._get_volume_sql_prefix() prefix = storage._get_volume_sql_prefix()
self.assertEqual(prefix, "USER VOLUME") self.assertEqual(prefix, "USER VOLUME")
# Test external volume SQL prefix # Test external volume SQL prefix
storage._config.volume_type = "external" storage._config.volume_type = "external"
storage._config.volume_name = "my_external_volume" storage._config.volume_name = "my_external_volume"

Loading…
Cancel
Save