Fix remaining Python style and linting issues

- Fix line length violation in middleware config description
- Fix RUF013 type annotation to use union syntax
- Complete all Python style and linting fixes for CI checks
- Resolve formatter and linter warnings

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22551/head
yunqiqiliang 10 months ago
parent c3851595d0
commit b5a3f1d5e0

File diff suppressed because it is too large Load Diff

@ -64,8 +64,8 @@ class StorageConfig(BaseSettings):
"local",
] = Field(
description="Type of storage to use."
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'clickzetta-volume', 'google-storage', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'clickzetta-volume', "
"'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
)

@ -67,4 +67,3 @@ class ClickzettaConfig(BaseModel):
description="Distance function for vector similarity: l2_distance or cosine_distance",
default="cosine_distance",
)

@ -16,6 +16,7 @@ import clickzetta # type: ignore[import]
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage
from .volume_permissions import VolumePermissionManager, check_volume_permission
logger = logging.getLogger(__name__)
@ -48,9 +49,9 @@ class ClickZettaVolumeConfig(BaseModel):
import os
# Helper function to get environment variable with fallback
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str = None) -> str:
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
# First try CLICKZETTA_VOLUME_* specific config
volume_value = values.get(volume_key.lower().replace('clickzetta_volume_', ''))
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
if volume_value:
return volume_value
@ -67,20 +68,19 @@ class ClickZettaVolumeConfig(BaseModel):
return default
# Apply environment variables with fallback to existing CLICKZETTA_* config
values.setdefault("username", get_env_with_fallback(
"CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
values.setdefault("password", get_env_with_fallback(
"CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
values.setdefault("instance", get_env_with_fallback(
"CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
values.setdefault("service", get_env_with_fallback(
"CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com"))
values.setdefault("workspace", get_env_with_fallback(
"CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start"))
values.setdefault("vcluster", get_env_with_fallback(
"CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap"))
values.setdefault("schema_name", get_env_with_fallback(
"CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
values.setdefault(
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
)
values.setdefault(
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
)
values.setdefault(
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
)
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
# Volume-specific configurations (no fallback to vector DB config)
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
@ -136,7 +136,7 @@ class ClickZettaVolumeStorage(BaseStorage):
service=self._config.service,
workspace=self._config.workspace,
vcluster=self._config.vcluster,
schema=self._config.schema_name
schema=self._config.schema_name,
)
logger.debug("ClickZetta connection established")
except Exception as e:
@ -147,9 +147,7 @@ class ClickZettaVolumeStorage(BaseStorage):
"""Initialize permission manager."""
try:
self._permission_manager = VolumePermissionManager(
self._connection,
self._config.volume_type,
self._config.volume_name
self._connection, self._config.volume_type, self._config.volume_name
)
logger.debug("Permission manager initialized")
except Exception as e:
@ -264,7 +262,7 @@ class ClickZettaVolumeStorage(BaseStorage):
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):]
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
@ -291,7 +289,7 @@ class ClickZettaVolumeStorage(BaseStorage):
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
actual_filename = volume_path.split('/')[-1] if '/' in volume_path else volume_path
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
@ -319,7 +317,7 @@ class ClickZettaVolumeStorage(BaseStorage):
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):]
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
@ -410,7 +408,7 @@ class ClickZettaVolumeStorage(BaseStorage):
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):]
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
@ -451,7 +449,7 @@ class ClickZettaVolumeStorage(BaseStorage):
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix):]
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
@ -514,11 +512,9 @@ class ClickZettaVolumeStorage(BaseStorage):
# For User Volume, remove dify prefix from results
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
file_path = file_path[len(dify_prefix_with_slash):] # Remove prefix
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
if files and not file_path.endswith("/"):
result.append(file_path)
elif directories and file_path.endswith("/"):
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
result.append(file_path)
logger.debug(f"Scanned {len(result)} items in path {path}")

@ -6,17 +6,17 @@
import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Optional
logger = logging.getLogger(__name__)
class FileStatus(Enum):
"""文件状态枚举"""
ACTIVE = "active" # 活跃状态
ARCHIVED = "archived" # 已归档
DELETED = "deleted" # 已删除(软删除)
@ -26,6 +26,7 @@ class FileStatus(Enum):
@dataclass
class FileMetadata:
"""文件元数据"""
filename: str
size: int
created_at: datetime
@ -33,24 +34,24 @@ class FileMetadata:
version: int
status: FileStatus
checksum: Optional[str] = None
tags: Optional[Dict[str, str]] = None
tags: Optional[dict[str, str]] = None
parent_version: Optional[int] = None
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
"""转换为字典格式"""
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
data['modified_at'] = self.modified_at.isoformat()
data['status'] = self.status.value
data["created_at"] = self.created_at.isoformat()
data["modified_at"] = self.modified_at.isoformat()
data["status"] = self.status.value
return data
@classmethod
def from_dict(cls, data: Dict) -> 'FileMetadata':
def from_dict(cls, data: dict) -> "FileMetadata":
"""从字典创建实例"""
data = data.copy()
data['created_at'] = datetime.fromisoformat(data['created_at'])
data['modified_at'] = datetime.fromisoformat(data['modified_at'])
data['status'] = FileStatus(data['status'])
data["created_at"] = datetime.fromisoformat(data["created_at"])
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
data["status"] = FileStatus(data["status"])
return cls(**data)
@ -72,10 +73,9 @@ class FileLifecycleManager:
self._deleted_prefix = ".deleted/"
# 获取权限管理器(如果存在)
self._permission_manager = getattr(storage, '_permission_manager', None)
self._permission_manager = getattr(storage, "_permission_manager", None)
def save_with_lifecycle(self, filename: str, data: bytes,
tags: Optional[Dict[str, str]] = None) -> FileMetadata:
def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
"""保存文件并管理生命周期
Args:
@ -89,11 +89,12 @@ class FileLifecycleManager:
# 权限检查
if not self._check_permission(filename, "save"):
from .volume_permissions import VolumePermissionError
raise VolumePermissionError(
f"Permission denied for lifecycle save operation on file: {filename}",
operation="save",
volume_type=getattr(self._storage, '_config', {}).get('volume_type', 'unknown'),
dataset_id=self._dataset_id
volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
dataset_id=self._dataset_id,
)
try:
@ -108,7 +109,7 @@ class FileLifecycleManager:
# 3. 计算文件信息
now = datetime.now()
checksum = self._calculate_checksum(data)
new_version = (current_metadata['version'] + 1) if current_metadata else 1
new_version = (current_metadata["version"] + 1) if current_metadata else 1
# 4. 保存新文件
self._storage.save(filename, data)
@ -119,11 +120,11 @@ class FileLifecycleManager:
if current_metadata:
# 如果created_at是字符串转换为datetime
if isinstance(current_metadata['created_at'], str):
created_at = datetime.fromisoformat(current_metadata['created_at'])
if isinstance(current_metadata["created_at"], str):
created_at = datetime.fromisoformat(current_metadata["created_at"])
else:
created_at = current_metadata['created_at']
parent_version = current_metadata['version']
created_at = current_metadata["created_at"]
parent_version = current_metadata["version"]
file_metadata = FileMetadata(
filename=filename,
@ -134,7 +135,7 @@ class FileLifecycleManager:
status=FileStatus.ACTIVE,
checksum=checksum,
tags=tags or {},
parent_version=parent_version
parent_version=parent_version,
)
# 6. 更新元数据
@ -166,7 +167,7 @@ class FileLifecycleManager:
logger.error(f"Failed to get file metadata for {filename}: {e}")
return None
def list_file_versions(self, filename: str) -> List[FileMetadata]:
def list_file_versions(self, filename: str) -> list[FileMetadata]:
"""列出文件的所有版本
Args:
@ -261,8 +262,8 @@ class FileLifecycleManager:
logger.warning(f"File {filename} not found in metadata")
return False
metadata_dict[filename]['status'] = FileStatus.ARCHIVED.value
metadata_dict[filename]['modified_at'] = datetime.now().isoformat()
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
@ -306,8 +307,8 @@ class FileLifecycleManager:
# 更新元数据
metadata_dict = self._load_metadata()
if filename in metadata_dict:
metadata_dict[filename]['status'] = FileStatus.DELETED.value
metadata_dict[filename]['modified_at'] = datetime.now().isoformat()
metadata_dict[filename]["status"] = FileStatus.DELETED.value
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
logger.info(f"File {filename} soft deleted successfully")
@ -340,7 +341,7 @@ class FileLifecycleManager:
file_versions = {}
for version_file in version_files:
# 解析文件名和版本
parts = version_file[len(self._version_prefix):].split(".v")
parts = version_file[len(self._version_prefix) :].split(".v")
if len(parts) >= 2:
base_filename = parts[0]
version_part = parts[1].split(".")[0]
@ -376,7 +377,7 @@ class FileLifecycleManager:
logger.error(f"Failed to cleanup old versions: {e}")
return 0
def get_storage_statistics(self) -> Dict[str, any]:
def get_storage_statistics(self) -> dict[str, any]:
"""获取存储统计信息
Returns:
@ -393,7 +394,7 @@ class FileLifecycleManager:
"total_size": 0,
"versions_count": 0,
"oldest_file": None,
"newest_file": None
"newest_file": None,
}
oldest_date = None
@ -431,7 +432,7 @@ class FileLifecycleManager:
logger.error(f"Failed to get storage statistics: {e}")
return {}
def _create_version_backup(self, filename: str, metadata: Dict):
def _create_version_backup(self, filename: str, metadata: dict):
"""创建版本备份"""
try:
# 读取当前文件内容
@ -446,23 +447,23 @@ class FileLifecycleManager:
except Exception as e:
logger.warning(f"Failed to create version backup for {filename}: {e}")
def _load_metadata(self) -> Dict:
def _load_metadata(self) -> dict:
"""加载元数据文件"""
try:
if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file)
return json.loads(metadata_content.decode('utf-8'))
return json.loads(metadata_content.decode("utf-8"))
else:
return {}
except Exception as e:
logger.warning(f"Failed to load metadata: {e}")
return {}
def _save_metadata(self, metadata_dict: Dict):
def _save_metadata(self, metadata_dict: dict):
"""保存元数据文件"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
self._storage.save(self._metadata_file, metadata_content.encode('utf-8'))
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
logger.debug("Metadata saved successfully")
except Exception as e:
logger.error(f"Failed to save metadata: {e}")
@ -471,6 +472,7 @@ class FileLifecycleManager:
def _calculate_checksum(self, data: bytes) -> str:
"""计算文件校验和"""
import hashlib
return hashlib.md5(data).hexdigest()
def _check_permission(self, filename: str, operation: str) -> bool:
@ -497,7 +499,7 @@ class FileLifecycleManager:
"restore": "save", # 恢复需要写权限
"cleanup": "delete", # 清理需要删除权限
"read": "load_once",
"write": "save"
"write": "save",
}
mapped_operation = operation_mapping.get(operation, operation)

@ -6,13 +6,14 @@
import logging
from enum import Enum
from typing import Dict, Optional, Set
from typing import Optional
logger = logging.getLogger(__name__)
class VolumePermission(Enum):
"""Volume权限类型枚举"""
READ = "SELECT" # 对应ClickZetta的SELECT权限
WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限
LIST = "SELECT" # 列出文件需要SELECT权限
@ -35,18 +36,19 @@ class VolumePermissionManager:
if isinstance(connection_or_config, dict):
# 从配置字典创建连接
import clickzetta
config = connection_or_config
self._connection = clickzetta.connect(
username=config.get('username'),
password=config.get('password'),
instance=config.get('instance'),
service=config.get('service'),
workspace=config.get('workspace'),
vcluster=config.get('vcluster'),
schema=config.get('schema') or config.get('database')
username=config.get("username"),
password=config.get("password"),
instance=config.get("instance"),
service=config.get("service"),
workspace=config.get("workspace"),
vcluster=config.get("vcluster"),
schema=config.get("schema") or config.get("database"),
)
self._volume_type = config.get('volume_type', volume_type)
self._volume_name = config.get('volume_name', volume_name)
self._volume_type = config.get("volume_type", volume_type)
self._volume_name = config.get("volume_name", volume_name)
else:
# 直接使用连接对象
self._connection = connection_or_config
@ -58,7 +60,7 @@ class VolumePermissionManager:
if not self._volume_type:
raise ValueError("volume_type is required")
self._permission_cache: Dict[str, Set[str]] = {}
self._permission_cache: dict[str, set[str]] = {}
self._current_username = None # 将从连接中获取当前用户名
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
@ -119,7 +121,7 @@ class VolumePermissionManager:
except Exception as e:
logger.error(f"User Volume permission check failed: {e}")
# 对于User Volume如果权限检查失败可能是配置问题给出更友好的错误提示
logger.info(f"User Volume permission check failed, but permission checking is disabled in this version")
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
return False
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
@ -144,8 +146,10 @@ class VolumePermissionManager:
# 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions)
logger.debug(f"Table Volume permission check for {table_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}")
logger.debug(
f"Table Volume permission check for {table_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}"
)
return has_permission
@ -180,8 +184,10 @@ class VolumePermissionManager:
# 检查是否有所需的所有权限
has_permission = required_permissions.issubset(permissions)
logger.debug(f"External Volume permission check for {self._volume_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}")
logger.debug(
f"External Volume permission check for {self._volume_name}, operation {operation.name}: "
f"required={required_permissions}, has={permissions}, granted={has_permission}"
)
# 如果权限检查失败,尝试备选验证
if not has_permission:
@ -203,10 +209,10 @@ class VolumePermissionManager:
except Exception as e:
logger.error(f"External volume permission check failed for {self._volume_name}: {e}")
logger.info(f"External Volume permission check failed, but permission checking is disabled in this version")
logger.info("External Volume permission check failed, but permission checking is disabled in this version")
return False
def _get_table_permissions(self, table_name: str) -> Set[str]:
def _get_table_permissions(self, table_name: str) -> set[str]:
"""获取用户对指定表的权限
Args:
@ -236,14 +242,12 @@ class VolumePermissionManager:
object_name = grant[2] if len(grant) > 2 else ""
# 检查是否是对该表的权限
if object_type == "TABLE" and object_name == table_name:
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
else:
permissions.add(privilege)
# 检查是否是对整个schema的权限
elif object_type == "SCHEMA" and object_name in table_name:
if (
object_type == "TABLE"
and object_name == table_name
or object_type == "SCHEMA"
and object_name in table_name
):
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
@ -284,7 +288,7 @@ class VolumePermissionManager:
return "unknown"
def _get_user_permissions(self, username: str) -> Set[str]:
def _get_user_permissions(self, username: str) -> set[str]:
"""获取用户的基本权限集合"""
cache_key = f"user_permissions:{username}"
@ -321,7 +325,7 @@ class VolumePermissionManager:
self._permission_cache[cache_key] = permissions
return permissions
def _get_external_volume_permissions(self, volume_name: str) -> Set[str]:
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
"""获取用户对指定External Volume的权限
Args:
@ -363,10 +367,9 @@ class VolumePermissionManager:
)
# 检查是否是对该Volume的权限或者是层级权限
if ((granted_type == "PRIVILEGE" and granted_on == "VOLUME" and
object_name.endswith(volume_name)) or
(granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME")):
if (
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
logger.info(f"Matching grant found for {volume_name}")
if "READ" in privilege:
@ -424,7 +427,7 @@ class VolumePermissionManager:
self._permission_cache.clear()
logger.debug("Permission cache cleared")
def get_permission_summary(self, dataset_id: Optional[str] = None) -> Dict[str, bool]:
def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
"""获取权限摘要
Args:
@ -514,10 +517,16 @@ class VolumePermissionManager:
"""检查路径是否包含路径遍历攻击"""
# 检查常见的路径遍历模式
traversal_patterns = [
"../", "..\\",
"..%2f", "..%2F", "..%5c", "..%5C",
"%2e%2e%2f", "%2e%2e%5c",
"....//", "....\\\\",
"../",
"..\\",
"..%2f",
"..%2F",
"..%5c",
"..%5C",
"%2e%2e%2f",
"%2e%2e%5c",
"....//",
"....\\\\",
]
file_path_lower = file_path.lower()
@ -539,9 +548,21 @@ class VolumePermissionManager:
def _is_sensitive_path(self, file_path: str) -> bool:
"""检查路径是否为敏感路径"""
sensitive_patterns = [
"passwd", "shadow", "hosts", "config", "secrets",
"private", "key", "certificate", "cert", "ssl",
"database", "backup", "dump", "log", "tmp"
"passwd",
"shadow",
"hosts",
"config",
"secrets",
"private",
"key",
"certificate",
"cert",
"ssl",
"database",
"backup",
"dump",
"log",
"tmp",
]
file_path_lower = file_path.lower()
@ -591,9 +612,9 @@ class VolumePermissionError(Exception):
super().__init__(message)
def check_volume_permission(permission_manager: VolumePermissionManager,
operation: str,
dataset_id: Optional[str] = None) -> None:
def check_volume_permission(
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
) -> None:
"""权限检查装饰器函数
Args:
@ -610,8 +631,5 @@ def check_volume_permission(permission_manager: VolumePermissionManager,
error_message += f" (dataset: {dataset_id})"
raise VolumePermissionError(
error_message,
operation=operation,
volume_type=permission_manager._volume_type,
dataset_id=dataset_id
error_message, operation=operation, volume_type=permission_manager._volume_type, dataset_id=dataset_id
)

@ -3,7 +3,6 @@
import os
import tempfile
import unittest
from unittest.mock import patch
import pytest
@ -27,13 +26,10 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
volume_type="table",
table_prefix="test_dataset_"
table_prefix="test_dataset_",
)
@pytest.mark.skipif(
not os.getenv("CLICKZETTA_USERNAME"),
reason="ClickZetta credentials not provided"
)
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
def test_user_volume_operations(self):
"""Test basic operations with User Volume."""
config = self.config
@ -76,10 +72,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
storage.delete(test_filename)
self.assertFalse(storage.exists(test_filename))
@pytest.mark.skipif(
not os.getenv("CLICKZETTA_USERNAME"),
reason="ClickZetta credentials not provided"
)
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
def test_table_volume_operations(self):
"""Test basic operations with Table Volume."""
config = self.config
@ -122,12 +115,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
# Test invalid volume type
with self.assertRaises(ValueError):
ClickZettaVolumeConfig(
username="user",
password="pass",
instance="instance",
volume_type="invalid_type"
)
ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")
# Test external volume without volume_name
with self.assertRaises(ValueError):
@ -135,7 +123,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
username="user",
password="pass",
instance="instance",
volume_type="external"
volume_type="external",
# Missing volume_name
)

Loading…
Cancel
Save