feat(api): Add `StorageKeyLoader` for loading _storage_key for `File`

pull/20699/head
QuantumGhost 11 months ago
parent c67a7abbe2
commit 2db7815098

@ -5,6 +5,7 @@ from typing import Any, cast
import httpx import httpx
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
@ -379,3 +380,74 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
def get_file_type_by_mime_type(mime_type: str) -> FileType:
    """Return the `FileType` for *mime_type*, defaulting to `FileType.CUSTOM`.

    Delegates to `_get_file_type_by_mimetype`; when that lookup yields a falsy
    result (it is annotated as returning `FileType | None`), `FileType.CUSTOM`
    is returned instead so callers always receive a `FileType`.
    """
    detected_type = _get_file_type_by_mimetype(mime_type)
    return detected_type or FileType.CUSTOM
class StorageKeyLoader:
    """Load storage keys from the database for a batch of `File` objects.

    The loader is batched: file ids are first grouped by the table that backs
    them, so a single call to `load_storage_keys` issues at most one query
    against `UploadFile` and one against `ToolFile`, no matter how many files
    are passed in. Every query is restricted to the tenant the loader was
    created for.
    """

    def __init__(self, session: Session, tenant_id: str) -> None:
        """Bind the loader to a database session and a single tenant.

        Args:
            session: SQLAlchemy session used for all lookups.
            tenant_id: Tenant whose files this loader is allowed to resolve.
        """
        self._session = session
        self._tenant_id = tenant_id

    def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
        """Return this tenant's `UploadFile` rows for the given ids, keyed by id."""
        if not upload_file_ids:
            # Nothing to look up: skip the round trip instead of sending an empty IN ().
            return {}
        stmt = select(UploadFile).where(
            UploadFile.id.in_(upload_file_ids),
            UploadFile.tenant_id == self._tenant_id,
        )
        return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}

    def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
        """Return this tenant's `ToolFile` rows for the given ids, keyed by id."""
        if not tool_file_ids:
            # Nothing to look up: skip the round trip instead of sending an empty IN ().
            return {}
        stmt = select(ToolFile).where(
            ToolFile.id.in_(tool_file_ids),
            ToolFile.tenant_id == self._tenant_id,
        )
        return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}

    def load_storage_keys(self, files: Sequence[File]) -> None:
        """Populate `_storage_key` on each file from its database record.

        Files whose transfer method is `LOCAL_FILE` or `REMOTE_URL` are resolved
        through `UploadFile` (using its `key`); `TOOL_FILE` files are resolved
        through `ToolFile` (using its `file_key`). Files with any other transfer
        method are left untouched. The input sequence itself is not modified,
        only the `_storage_key` attribute of the contained objects.

        All files are validated before any query is sent, so a single invalid
        file fails the whole batch without updating any of the others.

        Performance note: this is a batched operation and issues at most two
        queries regardless of input size. However, for optimal performance,
        input sequences should contain fewer than 1000 files. For larger
        collections, split into smaller batches and process each batch
        separately.

        Args:
            files: Files belonging to this loader's tenant.

        Raises:
            ValueError: If a file has no id, belongs to a different tenant, has
                an id string that is not a well-formed UUID, or has no matching
                database record for this tenant.
        """
        upload_file_ids: list[uuid.UUID] = []
        tool_file_ids: list[uuid.UUID] = []
        # Parsed ids, index-aligned with `files`, so each id is parsed only once.
        parsed_ids: list[uuid.UUID] = []

        for file in files:
            if file.id is None:
                raise ValueError("file id should not be None.")
            if file.tenant_id != self._tenant_id:
                err_msg = (
                    f"invalid file, expected tenant_id={self._tenant_id}, "
                    f"got tenant_id={file.tenant_id}, file_id={file.id}"
                )
                raise ValueError(err_msg)
            file_id = uuid.UUID(file.id)
            parsed_ids.append(file_id)

            if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
                upload_file_ids.append(file_id)
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_file_ids.append(file_id)

        tool_files = self._load_tool_files(tool_file_ids)
        upload_files = self._load_upload_files(upload_file_ids)

        for file, file_id in zip(files, parsed_ids):
            if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
                upload_file_row = upload_files.get(file_id)
                if upload_file_row is None:
                    raise ValueError(f"Upload file not found for id: {file_id}, tenant_id={self._tenant_id}")
                file._storage_key = upload_file_row.key
            elif file.transfer_method == FileTransferMethod.TOOL_FILE:
                tool_file_row = tool_files.get(file_id)
                if tool_file_row is None:
                    raise ValueError(f"Tool file not found for id: {file_id}, tenant_id={self._tenant_id}")
                file._storage_key = tool_file_row.file_key

@ -0,0 +1,354 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@pytest.mark.usefixtures("flask_req_ctx")
class TestStorageKeyLoader(unittest.TestCase):
    """Integration tests for the `StorageKeyLoader` class.

    Covers the batched loading of storage keys from the database for files
    with the LOCAL_FILE, REMOTE_URL and TOOL_FILE transfer methods, plus the
    validation errors raised for bad input and cross-tenant access.

    Rows created by the helpers are only flushed, never committed; `tearDown`
    rolls the session back, which is what discards them.
    """

    def setUp(self):
        """Create a session, fresh tenant/user ids and the loader under test."""
        self.session = db.session()
        self.tenant_id = str(uuid4())
        self.user_id = str(uuid4())
        self.conversation_id = str(uuid4())
        self.loader = StorageKeyLoader(self.session, self.tenant_id)

    def tearDown(self):
        """Discard everything the test flushed into the session."""
        self.session.rollback()

    def _create_upload_file(
        self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
    ) -> UploadFile:
        """Flush an `UploadFile` row; omitted arguments get random/default values."""
        if file_id is None:
            file_id = str(uuid4())
        if storage_key is None:
            storage_key = f"test_storage_key_{uuid4()}"
        if tenant_id is None:
            tenant_id = self.tenant_id

        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type="local",
            key=storage_key,
            name="test_file.txt",
            size=1024,
            extension=".txt",
            mime_type="text/plain",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=self.user_id,
            created_at=datetime.now(UTC),
            used=False,
        )
        upload_file.id = file_id

        self.session.add(upload_file)
        self.session.flush()
        return upload_file

    def _create_tool_file(
        self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
    ) -> ToolFile:
        """Flush a `ToolFile` row; omitted arguments get random/default values."""
        if file_id is None:
            file_id = str(uuid4())
        if file_key is None:
            file_key = f"test_file_key_{uuid4()}"
        if tenant_id is None:
            tenant_id = self.tenant_id

        tool_file = ToolFile()
        tool_file.id = file_id
        tool_file.user_id = self.user_id
        tool_file.tenant_id = tenant_id
        tool_file.conversation_id = self.conversation_id
        tool_file.file_key = file_key
        tool_file.mimetype = "text/plain"
        tool_file.original_url = "http://example.com/file.txt"
        tool_file.name = "test_tool_file.txt"
        tool_file.size = 2048

        self.session.add(tool_file)
        self.session.flush()
        return tool_file

    def _create_file(self, file_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None) -> File:
        """Build an in-memory `File` (not persisted) with the given id and transfer method."""
        if tenant_id is None:
            tenant_id = self.tenant_id

        # LOCAL_FILE and TOOL_FILE files carry a related_id; REMOTE_URL files carry a URL.
        related_id = None
        remote_url = None
        if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
            related_id = file_id
        elif transfer_method == FileTransferMethod.REMOTE_URL:
            remote_url = "https://example.com/test_file.txt"

        return File(
            id=file_id,
            tenant_id=tenant_id,
            type=FileType.DOCUMENT,
            transfer_method=transfer_method,
            related_id=related_id,
            remote_url=remote_url,
            filename="test_file.txt",
            extension=".txt",
            mime_type="text/plain",
            size=1024,
            storage_key="initial_key",
        )

    def test_load_storage_keys_local_file(self):
        """LOCAL_FILE files take their key from the matching UploadFile row."""
        upload_file = self._create_upload_file()
        file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)

        self.loader.load_storage_keys([file])

        assert file._storage_key == upload_file.key

    def test_load_storage_keys_remote_url(self):
        """REMOTE_URL files take their key from the matching UploadFile row."""
        upload_file = self._create_upload_file()
        file = self._create_file(upload_file.id, FileTransferMethod.REMOTE_URL)

        self.loader.load_storage_keys([file])

        assert file._storage_key == upload_file.key

    def test_load_storage_keys_tool_file(self):
        """TOOL_FILE files take their key from the matching ToolFile row."""
        tool_file = self._create_tool_file()
        file = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE)

        self.loader.load_storage_keys([file])

        assert file._storage_key == tool_file.file_key

    def test_load_storage_keys_mixed_methods(self):
        """A single batch may mix all three transfer methods."""
        upload_file1 = self._create_upload_file()
        upload_file2 = self._create_upload_file()
        tool_file = self._create_tool_file()

        file1 = self._create_file(upload_file1.id, FileTransferMethod.LOCAL_FILE)
        file2 = self._create_file(upload_file2.id, FileTransferMethod.REMOTE_URL)
        file3 = self._create_file(tool_file.id, FileTransferMethod.TOOL_FILE)

        self.loader.load_storage_keys([file1, file2, file3])

        assert file1._storage_key == upload_file1.key
        assert file2._storage_key == upload_file2.key
        assert file3._storage_key == tool_file.file_key

    def test_load_storage_keys_empty_list(self):
        """An empty batch is accepted without raising."""
        self.loader.load_storage_keys([])

    def test_load_storage_keys_tenant_mismatch(self):
        """A File declaring another tenant is rejected."""
        upload_file = self._create_upload_file()
        file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()))

        with pytest.raises(ValueError, match="invalid file, expected tenant_id"):
            self.loader.load_storage_keys([file])

    def test_load_storage_keys_missing_file_id(self):
        """A File whose id is None is rejected with a fixed message."""
        # Build a valid File first, then blank out the id.
        file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE)
        file.id = None

        with pytest.raises(ValueError) as context:
            self.loader.load_storage_keys([file])

        assert str(context.value) == "file id should not be None."

    def test_load_storage_keys_nonexistent_upload_file_records(self):
        """A LOCAL_FILE id with no UploadFile row raises ValueError."""
        file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE)

        with pytest.raises(ValueError):
            self.loader.load_storage_keys([file])

    def test_load_storage_keys_nonexistent_tool_file_records(self):
        """A TOOL_FILE id with no ToolFile row raises ValueError."""
        file = self._create_file(str(uuid4()), FileTransferMethod.TOOL_FILE)

        with pytest.raises(ValueError):
            self.loader.load_storage_keys([file])

    def test_load_storage_keys_invalid_uuid(self):
        """An id that is not a well-formed UUID string raises ValueError."""
        # Build a valid File first, then overwrite the id with garbage.
        file = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE)
        file.id = "invalid-uuid-format"

        with pytest.raises(ValueError):
            self.loader.load_storage_keys([file])

    def test_load_storage_keys_batch_efficiency(self):
        """A mixed batch is resolved with one query per backing table."""
        upload_files = [self._create_upload_file() for _ in range(3)]
        tool_files = [self._create_tool_file() for _ in range(2)]

        upload_backed = [self._create_file(uf.id, FileTransferMethod.LOCAL_FILE) for uf in upload_files]
        tool_backed = [self._create_file(tf.id, FileTransferMethod.TOOL_FILE) for tf in tool_files]

        # Wrap (not replace) `scalars` so queries still run but are counted.
        with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
            self.loader.load_storage_keys(upload_backed + tool_backed)

        # Exactly 2 queries: one for upload files, one for tool files.
        assert mock_scalars.call_count == 2

        for file, upload_file in zip(upload_backed, upload_files):
            assert file._storage_key == upload_file.key
        for file, tool_file in zip(tool_backed, tool_files):
            assert file._storage_key == tool_file.file_key

    def test_load_storage_keys_tenant_isolation(self):
        """Rows owned by another tenant can never be resolved by this loader."""
        other_tenant_id = str(uuid4())

        upload_file_current = self._create_upload_file()
        file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE)

        upload_file_other = self._create_upload_file(storage_key="other_tenant_key", tenant_id=other_tenant_id)

        # A File that openly declares the other tenant fails the up-front tenant check.
        file_other = self._create_file(upload_file_other.id, FileTransferMethod.LOCAL_FILE, other_tenant_id)
        with pytest.raises(ValueError, match="invalid file, expected tenant_id"):
            self.loader.load_storage_keys([file_other])

        # A File that claims the current tenant but points at the other tenant's row
        # passes that check; the tenant filter in the query must then hide the row,
        # so the lookup fails as "record not found".
        file_spoofed = self._create_file(upload_file_other.id, FileTransferMethod.LOCAL_FILE)
        with pytest.raises(ValueError):
            self.loader.load_storage_keys([file_spoofed])

        # The current tenant's own file still resolves.
        self.loader.load_storage_keys([file_current])
        assert file_current._storage_key == upload_file_current.key

    def test_load_storage_keys_mixed_tenant_batch(self):
        """One foreign-tenant file makes the whole batch fail."""
        upload_file_current = self._create_upload_file()
        file_current = self._create_file(upload_file_current.id, FileTransferMethod.LOCAL_FILE)

        file_other = self._create_file(str(uuid4()), FileTransferMethod.LOCAL_FILE, str(uuid4()))

        with pytest.raises(ValueError, match="invalid file, expected tenant_id"):
            self.loader.load_storage_keys([file_current, file_other])

    def test_load_storage_keys_duplicate_file_ids(self):
        """Two File objects sharing one id both receive the same key."""
        upload_file = self._create_upload_file()

        file1 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)
        file2 = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)

        self.loader.load_storage_keys([file1, file2])

        assert file1._storage_key == upload_file.key
        assert file2._storage_key == upload_file.key

    def test_load_storage_keys_session_isolation(self):
        """The loader queries through the session it was given, not a global one."""
        upload_file = self._create_upload_file()
        file = self._create_file(upload_file.id, FileTransferMethod.LOCAL_FILE)

        # A separate Session on the same engine. The row above was only flushed in
        # self.session, so it is presumably invisible here (depends on the database's
        # transaction isolation) and the lookup is expected to fail.
        with Session(bind=db.engine) as other_session:
            other_loader = StorageKeyLoader(other_session, self.tenant_id)
            with pytest.raises(ValueError):
                other_loader.load_storage_keys([file])
Loading…
Cancel
Save