pull/21682/merge
NeatGuyCoding 7 months ago committed by GitHub
commit f9fe2060fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -2128,174 +2128,367 @@ class SegmentService:
@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
    """
    Update a document segment with new content, keywords, and metadata.

    The work is delegated to private helpers so that every step (disable,
    simple update, complex update, error handling) runs exactly once.

    Args:
        args: Update arguments containing new content, keywords, and settings
        segment: The document segment to update
        document: The parent document containing the segment
        dataset: The dataset containing the document

    Returns:
        The ``segment`` argument itself when the call only disables it;
        otherwise the row re-queried from the database by ``segment.id``
        (the result of ``.first()``).

    Raises:
        ValueError: If the segment is currently being indexed, or if a
            disabled segment is updated without being enabled in the same call.
    """
    # A non-empty cache entry means an indexing job already owns this segment.
    indexing_cache_key = f"segment_{segment.id}_indexing"
    if redis_client.get(indexing_cache_key) is not None:
        raise ValueError("Segment is indexing, please try again later")

    # An explicit request to turn off a segment whose state differs is a
    # standalone operation: disable, queue index removal, and return at once.
    if args.enabled is not None and not args.enabled and segment.enabled != args.enabled:
        cls._disable_segment(segment, indexing_cache_key)
        return segment

    # A disabled segment may only be touched when this call also enables it
    # (args.enabled of None or False both fail this check).
    if not segment.enabled and not args.enabled:
        raise ValueError("Can't update disabled segment")

    try:
        # Remember the word count before any mutation so helpers can compute the delta.
        original_word_count = segment.word_count
        content = args.content or segment.content

        if segment.content == content:
            # Content unchanged: only answer/keywords/state may change.
            cls._handle_simple_segment_update(args, segment, document, dataset, original_word_count)
        else:
            # Content changed: re-hash, re-count tokens and refresh the index.
            cls._handle_complex_segment_update(args, segment, document, dataset, content, original_word_count)
    except Exception as e:
        # Failures are not re-raised; the segment is flagged as errored instead.
        cls._handle_segment_update_error(segment, e)

    # Return a fresh segment object from the database.
    new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
    return new_segment
@classmethod
def _disable_segment(cls, segment: DocumentSegment, indexing_cache_key: str):
    """
    Switch a segment off and queue its removal from the index.

    Args:
        segment: Segment that should be marked as disabled
        indexing_cache_key: Redis key that acts as the indexing lock for this segment
    """
    segment.enabled = False
    disabled_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    segment.disabled_at = disabled_time
    segment.disabled_by = current_user.id

    session = db.session
    session.add(segment)
    session.commit()

    # Hold the lock key for 600 seconds so the same segment is not indexed twice,
    # then hand the index removal over to the background task.
    redis_client.setex(indexing_cache_key, 600, 1)
    disable_segment_from_index_task.delay(segment.id)
@classmethod
def _handle_simple_segment_update(
    cls,
    args: SegmentUpdateArgs,
    segment: DocumentSegment,
    document: Document,
    dataset: Dataset,
    original_word_count: int,
):
    """
    Handle a segment update when the content itself is unchanged.

    Refreshes the word count, the QA answer (for ``qa_model`` documents) and
    the keywords, re-enables the segment, persists segment and document in a
    single commit, and finally delegates the index refresh.

    Args:
        args: Update arguments
        segment: The segment to update
        document: Parent document
        dataset: Parent dataset
        original_word_count: Segment word count before this update
    """
    # Recompute the word count from the (unchanged) main content.
    segment.word_count = len(segment.content)

    # QA documents also carry an answer that contributes to the word count.
    if document.doc_form == "qa_model":
        segment.answer = args.answer
        segment.word_count += len(args.answer) if args.answer else 0

    word_count_change = segment.word_count - original_word_count

    # Keywords are compared as multisets, so order does not count as a change.
    keyword_changed = False
    if args.keywords and Counter(segment.keywords) != Counter(args.keywords):
        segment.keywords = args.keywords
        keyword_changed = True

    # Reset segment state to enabled.
    segment.enabled = True
    segment.disabled_at = None
    segment.disabled_by = None
    db.session.add(segment)

    # Stage the document word count BEFORE committing so that it is persisted
    # together with the segment, as the content-changed path already does.
    if word_count_change != 0:
        document.word_count = max(0, document.word_count + word_count_change)
        db.session.add(document)

    db.session.commit()

    # Content did not change, hence content_changed=False.
    cls._handle_vector_index_updates(args, segment, document, dataset, keyword_changed, False)
@classmethod
def _handle_complex_segment_update(
    cls,
    args: SegmentUpdateArgs,
    segment: DocumentSegment,
    document: Document,
    dataset: Dataset,
    content: str,
    original_word_count: int,
):
    """
    Apply an update whose content differs from what the segment currently stores.

    Args:
        args: Update arguments
        segment: Segment being rewritten
        document: Document that owns the segment
        dataset: Dataset that owns the document
        content: Replacement content
        original_word_count: Segment word count captured before the update
    """
    # Hash and token count are both derived from the new content.
    new_hash = helper.generate_text_hash(content)
    token_count = cls._calculate_segment_tokens(args, content, document, dataset)

    # Write content, counters, timestamps and enabled state onto the segment.
    cls._update_segment_metadata(document, segment, content, new_hash, token_count, args)

    # Propagate the word count delta to the parent document, never below zero.
    delta = segment.word_count - original_word_count
    if delta != 0:
        document.word_count = max(0, document.word_count + delta)
        db.session.add(document)

    session = db.session
    session.add(segment)
    session.commit()

    # Both flags are True: a content change always counts as a keyword and content change here.
    cls._handle_vector_index_updates(args, segment, document, dataset, True, True)
@classmethod
def _calculate_segment_tokens(
    cls, args: SegmentUpdateArgs, content: str, document: Document, dataset: Dataset
) -> int:
    """
    Count the embedding tokens for the segment text.

    Args:
        args: Update arguments (supplies the answer for QA documents)
        content: Segment content
        document: Parent document
        dataset: Parent dataset

    Returns:
        int: 0 unless the dataset uses "high_quality" indexing; otherwise the
        first entry of the token counts reported by the embedding model.
    """
    if dataset.indexing_technique != "high_quality":
        return 0

    embedding_model = ModelManager().get_model_instance(
        tenant_id=current_user.current_tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,
        model=dataset.embedding_model,
    )

    # QA documents are counted as content followed by the answer (empty when absent).
    if document.doc_form == "qa_model":
        text = content + (args.answer or "")
    else:
        text = content
    return embedding_model.get_text_embedding_num_tokens(texts=[text])[0]
@classmethod
def _update_segment_metadata(
    cls,
    document: Document,
    segment: DocumentSegment,
    content: str,
    segment_hash: str,
    tokens: int,
    args: SegmentUpdateArgs,
):
    """
    Write new content, counters, timestamps and state onto the segment.

    Args:
        document: Parent document; its doc_form decides whether the answer is stored
        segment: The segment to update
        content: New content
        segment_hash: Hash stored as the segment's index_node_hash
        tokens: Token count for the content
        args: Update arguments (supplies the answer for QA documents)
    """
    # One naive UTC timestamp is shared by all three time fields.
    now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

    segment.content = content
    segment.index_node_hash = segment_hash
    segment.tokens = tokens
    segment.status = "completed"

    segment.indexing_at = now
    segment.completed_at = now
    segment.updated_at = now
    segment.updated_by = current_user.id

    # A successful content update always leaves the segment enabled.
    segment.enabled = True
    segment.disabled_at = None
    segment.disabled_by = None

    # Word count is the content length, plus the answer length for QA documents.
    total_words = len(content)
    if document.doc_form == "qa_model":
        segment.answer = args.answer
        total_words += len(args.answer) if args.answer else 0
    segment.word_count = total_words
@classmethod
def _handle_vector_index_updates(
    cls,
    args: SegmentUpdateArgs,
    segment: DocumentSegment,
    document: Document,
    dataset: Dataset,
    keyword_changed: bool,
    content_changed: bool,
):
    """
    Refresh the index for a segment according to the document form.

    Args:
        args: Update arguments
        segment: Segment whose index entry may need refreshing
        document: Parent document
        dataset: Parent dataset
        keyword_changed: Whether this update changed the keywords
        content_changed: Whether this update changed the content
    """
    # Parent-child documents only rebuild their child chunks on explicit request.
    if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
        cls._regenerate_child_chunks(segment, document, dataset)
        return

    # Any other form besides paragraph/QA needs no index work here.
    if document.doc_form not in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
        return

    # Paragraph/QA vectors are refreshed only when something relevant changed.
    if content_changed or args.enabled or keyword_changed:
        VectorService.update_segment_vector(args.keywords, segment, dataset)
@classmethod
def _regenerate_child_chunks(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
    """
    Rebuild the child chunks of a segment in a parent-child indexed document.

    Args:
        segment: Segment whose child chunks are rebuilt
        document: Parent document
        dataset: Parent dataset

    Raises:
        ValueError: If the dataset is not indexed with "high_quality", or if the
            document's processing rule cannot be found
    """
    if dataset.indexing_technique != "high_quality":
        raise ValueError("The knowledge base index technique is not high quality!")

    model_instance = cls._get_embedding_model_instance(dataset)

    # Look up the processing rule referenced by the document.
    rule_query = db.session.query(DatasetProcessRule).filter(
        DatasetProcessRule.id == document.dataset_process_rule_id
    )
    rule = rule_query.first()
    if not rule:
        raise ValueError("No processing rule found.")

    VectorService.generate_child_chunks(segment, document, dataset, model_instance, rule, True)
@classmethod
def _get_embedding_model_instance(cls, dataset: Dataset):
    """
    Resolve the text-embedding model instance for a dataset.

    Args:
        dataset: Dataset whose tenant and embedding settings are used

    Returns:
        The dataset's configured embedding model when a provider is set,
        otherwise the tenant's default text-embedding model instance.
    """
    manager = ModelManager()

    # No provider configured: fall back to the tenant default.
    if not dataset.embedding_model_provider:
        return manager.get_default_model_instance(
            tenant_id=dataset.tenant_id,
            model_type=ModelType.TEXT_EMBEDDING,
        )

    return manager.get_model_instance(
        tenant_id=dataset.tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,
        model=dataset.embedding_model,
    )
@classmethod
def _handle_segment_update_error(cls, segment: DocumentSegment, error: Exception):
    """
    Record a failed update on the segment: log it, disable the segment and
    store the error text.

    Args:
        segment: Segment whose update failed
        error: Exception raised during the update
    """
    logging.exception("update segment index failed")

    failed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    message = str(error)

    segment.enabled = False
    segment.disabled_at = failed_at
    segment.status = "error"
    segment.error = message
    db.session.commit()
@classmethod
def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
indexing_cache_key = "segment_{}_delete_indexing".format(segment.id)

@ -0,0 +1,872 @@
import datetime
from typing import Any, Optional
# Mock redis_client before importing segment_service
from unittest.mock import Mock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexType
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
from services.dataset_service import SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from tests.unit_tests.conftest import redis_mock
class SegmentUpdateTestDataFactory:
    """Builders for the mock objects used by the segment update tests."""

    @staticmethod
    def create_segment_mock(
        segment_id: str = "segment-123",
        content: str = "old_content",
        answer: Optional[str] = None,
        keywords: Optional[list[str]] = None,
        enabled: bool = True,
        status: str = "completed",
        word_count: int = 10,
        tokens: int = 15,
        position: int = 1,
        **kwargs,
    ) -> Mock:
        """Build a DocumentSegment-spec'd mock; extra kwargs are set as attributes last."""
        stamp = datetime.datetime(2023, 1, 1, 12, 0, 0)
        attributes = {
            "id": segment_id,
            "content": content,
            "answer": answer,
            "keywords": keywords or [],
            "enabled": enabled,
            "status": status,
            "word_count": word_count,
            "tokens": tokens,
            "position": position,
            "index_node_id": f"node-{segment_id}",
            "index_node_hash": f"hash-{segment_id}",
            "tenant_id": "tenant-123",
            "dataset_id": "dataset-123",
            "document_id": "document-123",
            "created_by": "user-789",
            "created_at": stamp,
            "updated_by": None,
            "updated_at": stamp,
            "indexing_at": stamp,
            "completed_at": stamp,
            "disabled_at": None,
            "disabled_by": None,
            "error": None,
        }
        # Caller-supplied overrides win over the defaults above.
        attributes.update(kwargs)

        segment = Mock(spec=DocumentSegment)
        for name, value in attributes.items():
            setattr(segment, name, value)
        return segment

    @staticmethod
    def create_document_mock(
        document_id: str = "document-123",
        doc_form: str = IndexType.PARAGRAPH_INDEX,
        word_count: int = 100,
        dataset_process_rule_id: str = "rule-123",
        **kwargs,
    ) -> Mock:
        """Build a Document-spec'd mock; extra kwargs are set as attributes last."""
        attributes = {
            "id": document_id,
            "doc_form": doc_form,
            "word_count": word_count,
            "dataset_process_rule_id": dataset_process_rule_id,
            "tenant_id": "tenant-123",
            "dataset_id": "dataset-123",
            "created_by": "user-789",
            "created_at": datetime.datetime(2023, 1, 1, 12, 0, 0),
        }
        attributes.update(kwargs)

        document = Mock(spec=Document)
        for name, value in attributes.items():
            setattr(document, name, value)
        return document

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        indexing_technique: str = "high_quality",
        embedding_model_provider: str = "openai",
        embedding_model: str = "text-embedding-ada-002",
        tenant_id: str = "tenant-123",
        **kwargs,
    ) -> Mock:
        """Build a Dataset-spec'd mock; extra kwargs are set as attributes last."""
        attributes = {
            "id": dataset_id,
            "indexing_technique": indexing_technique,
            "embedding_model_provider": embedding_model_provider,
            "embedding_model": embedding_model,
            "tenant_id": tenant_id,
        }
        attributes.update(kwargs)

        dataset = Mock(spec=Dataset)
        for name, value in attributes.items():
            setattr(dataset, name, value)
        return dataset

    @staticmethod
    def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
        """Build an embedding model mock whose token counter always reports [20]."""
        model_mock = Mock()
        model_mock.model = model
        model_mock.provider = provider
        model_mock.get_text_embedding_num_tokens.return_value = [20]
        return model_mock

    @staticmethod
    def create_processing_rule_mock(rule_id: str = "rule-123") -> Mock:
        """Build a DatasetProcessRule-spec'd mock with a fixed to_dict() payload."""
        rule = Mock(spec=DatasetProcessRule)
        rule.id = rule_id
        rule.to_dict.return_value = {"rules": {"parent_mode": "full_doc"}}
        return rule

    @staticmethod
    def create_current_user_mock(user_id: str = "user-789", tenant_id: str = "tenant-123") -> Mock:
        """Build a mock user exposing id and current_tenant_id."""
        user = Mock()
        user.id = user_id
        user.current_tenant_id = tenant_id
        return user
class TestSegmentServiceUpdateSegment:
"""
Comprehensive unit tests for SegmentService.update_segment method.
This test suite covers all supported scenarios including:
- Segment enable/disable operations
- Content updates with same and different content
- QA model updates with answer field
- Keyword updates
- Different document forms (paragraph, QA, parent-child)
- High quality vs economy indexing techniques
- Child chunk regeneration
- Error handling and edge cases
- Redis cache management
- Vector index updates
"""
@pytest.fixture
def mock_segment_service_dependencies(self):
    """Patch redis, the db session, datetime, current_user and helper for one test.

    Yields a dict of the patched objects plus the frozen "current_time".
    """
    frozen_now = datetime.datetime(2023, 1, 1, 12, 0, 0)
    with (
        patch("services.dataset_service.redis_client") as redis_patch,
        patch("extensions.ext_database.db.session") as session_patch,
        patch("services.dataset_service.datetime") as datetime_patch,
        patch("services.dataset_service.current_user") as user_patch,
        patch("services.dataset_service.helper") as helper_patch,
    ):
        # Freeze "now" while keeping the real UTC constant available.
        datetime_patch.datetime.now.return_value = frozen_now
        datetime_patch.UTC = datetime.UTC

        user_patch.id = "user-789"
        user_patch.current_tenant_id = "tenant-123"

        helper_patch.generate_text_hash.return_value = "new_hash_123"

        dependencies = {
            "redis_client": redis_patch,
            "db_session": session_patch,
            "datetime": datetime_patch,
            "current_user": user_patch,
            "helper": helper_patch,
            "current_time": frozen_now,
        }
        yield dependencies
@pytest.fixture
def mock_model_manager_dependencies(self):
    """Patch ModelManager in services.dataset_service and yield the patched class."""
    model_manager_patch = patch("services.dataset_service.ModelManager")
    with model_manager_patch as patched_manager:
        yield patched_manager
@pytest.fixture
def mock_vector_service_dependencies(self):
    """Patch VectorService and the disable-from-index task; yield both in a dict."""
    with (
        patch("services.dataset_service.VectorService") as vector_patch,
        patch("services.dataset_service.disable_segment_from_index_task") as task_patch,
    ):
        dependencies = {"vector_service": vector_patch, "disable_task": task_patch}
        yield dependencies
@pytest.fixture
def mock_processing_rule_dependencies(self):
    """Patch DatasetProcessRule in services.dataset_service and yield the patched class."""
    rule_patch = patch("services.dataset_service.DatasetProcessRule")
    with rule_patch as patched_rule:
        yield patched_rule
def _assert_redis_cache_operations(self, mock_redis, segment_id: str, should_set_cache: bool = False):
"""Helper method to verify Redis cache operations."""
mock_redis.get.assert_called_once_with(f"segment_{segment_id}_indexing")
if should_set_cache:
mock_redis.setex.assert_called_once_with(f"segment_{segment_id}_indexing", 600, 1)
def _assert_database_operations(self, mock_db, expected_objects: list[Any]):
"""Helper method to verify database operations.
Args:
mock_db: Mock database session
expected_objects: List of objects that should have been added to the session.
If None, only verifies that add() was called at least once.
"""
if expected_objects is None:
# Just verify that add was called at least once
assert mock_db.add.call_count >= 1
else:
# Get all the objects that were passed to add()
added_objects = [call.args[0] for call in mock_db.add.call_args_list]
# Verify each expected object was added
for expected_obj in expected_objects:
assert expected_obj in added_objects, f"Expected object {expected_obj} was not added to session"
def _assert_segment_attributes_updated(self, segment, expected_updates: dict[str, Any]):
"""Helper method to verify segment attribute updates."""
for key, value in expected_updates.items():
assert getattr(segment, key) == value
# ==================== Enable/Disable Segment Tests ====================
def test_disable_segment_success(self, mock_segment_service_dependencies, mock_vector_service_dependencies):
    """Disabling an enabled segment persists the change and queues index removal."""
    deps = mock_segment_service_dependencies
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock(enabled=True)
    document = factory.create_document_mock()
    dataset = factory.create_dataset_mock()

    # No indexing lock is held for this segment.
    deps["redis_client"].get.return_value = None

    update_args = SegmentUpdateArgs(enabled=False)
    returned = SegmentService.update_segment(update_args, segment, document, dataset)

    # The lock was checked and then set for the disable task.
    self._assert_redis_cache_operations(deps["redis_client"], segment.id, should_set_cache=True)

    # The segment carries the disabled state, timestamp and acting user.
    expected_state = {
        "enabled": False,
        "disabled_at": deps["current_time"].replace(tzinfo=None),
        "disabled_by": deps["current_user"].id,
    }
    self._assert_segment_attributes_updated(segment, expected_state)

    # The segment was added to the session.
    self._assert_database_operations(deps["db_session"], [segment])

    # Index removal was handed to the background task.
    mock_vector_service_dependencies["disable_task"].delay.assert_called_once_with(segment.id)

    # The very same segment object is returned.
    assert returned == segment
def test_disable_segment_already_disabled(self, mock_segment_service_dependencies):
    """Disabling a segment that is already disabled is rejected."""
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock(enabled=False)
    document = factory.create_document_mock()
    dataset = factory.create_dataset_mock()

    # No indexing lock is held for this segment.
    mock_segment_service_dependencies["redis_client"].get.return_value = None

    update_args = SegmentUpdateArgs(enabled=False)
    with pytest.raises(ValueError, match="Can't update disabled segment"):
        SegmentService.update_segment(update_args, segment, document, dataset)
def test_enable_segment_no_change(self, mock_segment_service_dependencies, mock_vector_service_dependencies):
    """Enabling a segment that is already enabled leaves its state untouched."""
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock(enabled=True)
    document = factory.create_document_mock()
    dataset = factory.create_dataset_mock()

    # No indexing lock is held for this segment.
    mock_segment_service_dependencies["redis_client"].get.return_value = None

    SegmentService.update_segment(SegmentUpdateArgs(enabled=True), segment, document, dataset)

    # Still enabled, with no disable bookkeeping recorded.
    assert segment.enabled is True
    assert segment.disabled_at is None
    assert segment.disabled_by is None
def test_update_disabled_segment_without_enable(self, mock_segment_service_dependencies):
    """Changing the content of a disabled segment without enabling it is rejected."""
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock(enabled=False)
    document = factory.create_document_mock()
    dataset = factory.create_dataset_mock()

    # No indexing lock is held for this segment.
    mock_segment_service_dependencies["redis_client"].get.return_value = None

    update_args = SegmentUpdateArgs(content="new_content")
    with pytest.raises(ValueError, match="Can't update disabled segment"):
        SegmentService.update_segment(update_args, segment, document, dataset)
def test_segment_indexing_in_progress_error(self, mock_segment_service_dependencies):
    """An update is rejected while the segment's indexing lock is present."""
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock()
    document = factory.create_document_mock()
    dataset = factory.create_dataset_mock()

    # A non-None cache value signals that indexing is in progress.
    mock_segment_service_dependencies["redis_client"].get.return_value = "1"

    update_args = SegmentUpdateArgs(content="new_content")
    with pytest.raises(ValueError, match="Segment is indexing, please try again later"):
        SegmentService.update_segment(update_args, segment, document, dataset)
# ==================== Content Update Tests (Same Content) ====================
def test_update_segment_same_content_success(
    self, mock_segment_service_dependencies, mock_vector_service_dependencies
):
    """Same content on a QA document: only keywords and answer are replaced."""
    deps = mock_segment_service_dependencies
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock(content="test_content", keywords=["old_keyword"], answer="old_answer")
    document = factory.create_document_mock(doc_form=IndexType.QA_INDEX)
    dataset = factory.create_dataset_mock()

    # No indexing lock is held for this segment.
    deps["redis_client"].get.return_value = None

    # Content is identical to what the segment already stores.
    update_args = SegmentUpdateArgs(content="test_content", keywords=["new_keyword"], answer="new_answer")
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Word count is content length plus answer length.
    expected_state = {
        "keywords": ["new_keyword"],
        "answer": "new_answer",
        "enabled": True,
        "disabled_at": None,
        "disabled_by": None,
        "word_count": len("test_content") + len("new_answer"),
    }
    self._assert_segment_attributes_updated(segment, expected_state)

    # The segment was added to the session.
    self._assert_database_operations(deps["db_session"], [segment])
def test_update_segment_same_content_keywords_unchanged(
    self, mock_segment_service_dependencies, mock_vector_service_dependencies
):
    """Same content and same keywords: the keyword list stays as it was."""
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock(content="test_content", keywords=["keyword1", "keyword2"])
    document = factory.create_document_mock()
    dataset = factory.create_dataset_mock()

    # No indexing lock is held for this segment.
    mock_segment_service_dependencies["redis_client"].get.return_value = None

    # Both content and keywords match the stored values.
    update_args = SegmentUpdateArgs(content="test_content", keywords=["keyword1", "keyword2"])
    SegmentService.update_segment(update_args, segment, document, dataset)

    assert segment.keywords == ["keyword1", "keyword2"]
def test_update_segment_same_content_no_keywords_provided(
    self, mock_segment_service_dependencies, mock_vector_service_dependencies
):
    """Same content and no keywords in the request: existing keywords are kept."""
    factory = SegmentUpdateTestDataFactory
    segment = factory.create_segment_mock(content="test_content", keywords=["existing_keyword"])
    document = factory.create_document_mock()
    dataset = factory.create_dataset_mock()

    # No indexing lock is held for this segment.
    mock_segment_service_dependencies["redis_client"].get.return_value = None

    # The request carries content only.
    update_args = SegmentUpdateArgs(content="test_content")
    SegmentService.update_segment(update_args, segment, document, dataset)

    assert segment.keywords == ["existing_keyword"]
# ==================== Content Update Tests (Different Content) ====================
def test_update_segment_different_content_paragraph_index(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies
):
    """New content on a paragraph-index segment rewrites the segment and refreshes its vector."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content", word_count=10, tokens=15)
    document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARAGRAPH_INDEX)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    # The model manager hands back the mocked embedding model
    model_manager = mock_model_manager_dependencies.return_value
    model_manager.get_model_instance.return_value = SegmentUpdateTestDataFactory.create_embedding_model_mock()

    update_args = SegmentUpdateArgs(content="new_content", keywords=["new_keyword"])
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Attribute values expected on the segment after the content change
    expected_attributes = {
        "content": "new_content",
        "index_node_hash": "new_hash_123",
        "word_count": len("new_content"),
        "tokens": 20,  # From mock embedding model
        "status": "completed",
        "keywords": [],  # keywords are not updated for content update
        "enabled": True,
        "disabled_at": None,
        "disabled_by": None,
    }
    self._assert_segment_attributes_updated(segment, expected_attributes)

    # The segment was handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])

    # The vector service received the new keywords exactly once
    vector_service = mock_vector_service_dependencies["vector_service"]
    vector_service.update_segment_vector.assert_called_once_with(["new_keyword"], segment, dataset)
def test_update_segment_different_content_qa_index(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies
):
    """New question and answer on a QA-index segment update both fields and count both in word_count."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(
        content="old_content", answer="old_answer", word_count=10, tokens=15
    )
    document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.QA_INDEX)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    # The model manager hands back the mocked embedding model
    embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock()
    model_manager = mock_model_manager_dependencies.return_value
    model_manager.get_model_instance.return_value = embedding_model

    update_args = SegmentUpdateArgs(content="new_content", answer="new_answer", keywords=["new_keyword"])
    SegmentService.update_segment(update_args, segment, document, dataset)

    # For QA segments the word count covers question plus answer
    expected_attributes = {
        "content": "new_content",
        "answer": "new_answer",
        "index_node_hash": "new_hash_123",
        "word_count": len("new_content") + len("new_answer"),
        "tokens": 20,  # From mock embedding model
        "status": "completed",
        "keywords": [],  # keywords are not updated for content update
        "enabled": True,
        "disabled_at": None,
        "disabled_by": None,
    }
    self._assert_segment_attributes_updated(segment, expected_attributes)

    # Token counting was requested once, on question and answer concatenated
    embedding_model.get_text_embedding_num_tokens.assert_called_once_with(texts=["new_contentnew_answer"])

    # The vector service received the new keywords exactly once
    vector_service = mock_vector_service_dependencies["vector_service"]
    vector_service.update_segment_vector.assert_called_once_with(["new_keyword"], segment, dataset)

    # The segment was handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])
def test_update_segment_different_content_economy_indexing(
    self, mock_segment_service_dependencies, mock_vector_service_dependencies
):
    """New content on a dataset using the economy indexing technique is stored with zero tokens."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content", word_count=10, tokens=15)
    document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARAGRAPH_INDEX)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock(indexing_technique="economy")

    update_args = SegmentUpdateArgs(content="new_content", keywords=["new_keyword"])
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Attribute values expected after the update (no token calculation for economy)
    expected_attributes = {
        "content": "new_content",
        "index_node_hash": "new_hash_123",
        "word_count": len("new_content"),
        "tokens": 0,  # No tokens calculation for economy
        "status": "completed",
        "keywords": [],  # keywords are not updated for content update
        "enabled": True,
        "disabled_at": None,
        "disabled_by": None,
    }
    self._assert_segment_attributes_updated(segment, expected_attributes)

    # The vector service received the new keywords exactly once
    vector_service = mock_vector_service_dependencies["vector_service"]
    vector_service.update_segment_vector.assert_called_once_with(["new_keyword"], segment, dataset)

    # The segment was handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])
# ==================== Parent-Child Index Tests ====================
def test_update_segment_parent_child_index_regenerate_child_chunks(
    self,
    mock_segment_service_dependencies,
    mock_model_manager_dependencies,
    mock_vector_service_dependencies,
    mock_processing_rule_dependencies,
):
    """A parent-child segment updated with regenerate_child_chunks=True triggers child chunk generation."""
    db_session = mock_segment_service_dependencies["db_session"]

    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content")
    document = SegmentUpdateTestDataFactory.create_document_mock(
        doc_form=IndexType.PARENT_CHILD_INDEX, dataset_process_rule_id="rule-123"
    )
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    # The model manager hands back the mocked embedding model
    embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock()
    model_manager = mock_model_manager_dependencies.return_value
    model_manager.get_model_instance.return_value = embedding_model

    # The session's query(...).filter(...).first() chain yields the processing rule
    processing_rule = SegmentUpdateTestDataFactory.create_processing_rule_mock()
    db_session.query.return_value.filter.return_value.first.return_value = processing_rule

    update_args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=True, keywords=["new_keyword"])
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Child chunks were regenerated once with the resolved model and rule
    vector_service = mock_vector_service_dependencies["vector_service"]
    vector_service.generate_child_chunks.assert_called_once_with(
        segment, document, dataset, embedding_model, processing_rule, True
    )

    # The segment was handed to the database session
    self._assert_database_operations(db_session, [segment])
def test_update_segment_parent_child_index_no_regenerate(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies
):
    """A parent-child segment updated with regenerate_child_chunks=False never generates child chunks."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content")
    document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARENT_CHILD_INDEX)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    # The model manager hands back the mocked embedding model
    model_manager = mock_model_manager_dependencies.return_value
    model_manager.get_model_instance.return_value = SegmentUpdateTestDataFactory.create_embedding_model_mock()

    update_args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=False)
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Child chunk generation was skipped
    vector_service = mock_vector_service_dependencies["vector_service"]
    vector_service.generate_child_chunks.assert_not_called()

    # The segment was handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])
def test_update_segment_parent_child_index_economy_technique_error(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies
):
    """Test error when trying to regenerate child chunks with economy indexing technique.

    The call itself is expected to return normally; the failure is observed on
    the segment, whose status and error fields are checked afterwards.
    """
    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content")
    document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARENT_CHILD_INDEX)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock(indexing_technique="economy")
    # Mock Redis cache as not indexing
    mock_segment_service_dependencies["redis_client"].get.return_value = None
    args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=True)
    SegmentService.update_segment(args, segment, document, dataset)
    # Compare the status exactly: a substring check would also accept values
    # such as "no_error" and is weaker than the other error-path assertions.
    assert segment.status == "error"
    assert "The knowledge base index technique is not high quality!" in str(segment.error)
def test_update_segment_parent_child_index_no_processing_rule_error(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies
):
    """Test error when processing rule is not found for parent-child index.

    The processing-rule lookup is stubbed to return None. The call itself is
    expected to return normally; the failure is observed on the segment, whose
    status and error fields are checked afterwards.
    """
    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content")
    document = SegmentUpdateTestDataFactory.create_document_mock(
        doc_form=IndexType.PARENT_CHILD_INDEX, dataset_process_rule_id="rule-123"
    )
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()
    # Mock Redis cache as not indexing
    mock_segment_service_dependencies["redis_client"].get.return_value = None
    # Mock embedding model
    embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock()
    mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model
    # Mock processing rule query returning None
    mock_segment_service_dependencies["db_session"].query.return_value.filter.return_value.first.return_value = None
    args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=True)
    SegmentService.update_segment(args, segment, document, dataset)
    # Compare the status exactly: a substring check would also accept values
    # such as "no_error" and is weaker than the other error-path assertions.
    assert segment.status == "error"
    assert "No processing rule found." in str(segment.error)
# ==================== Document Word Count Update Tests ====================
def test_update_segment_word_count_increase(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies
):
    """Longer replacement content raises the document word count by the segment's growth."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content", word_count=10)
    document = SegmentUpdateTestDataFactory.create_document_mock(word_count=100)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    update_args = SegmentUpdateArgs(content="new_content_much_longer_than_old")
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Document total = starting total (100) + new segment length - old segment count (10)
    segment_delta = len("new_content_much_longer_than_old") - 10
    assert document.word_count == 100 + segment_delta

    # Both the segment and the document were handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment, document])
def test_update_segment_word_count_decrease(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies
):
    """Shorter replacement content lowers the document word count by the segment's shrinkage."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="very_long_old_content", word_count=25)
    document = SegmentUpdateTestDataFactory.create_document_mock(word_count=100)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    update_args = SegmentUpdateArgs(content="short")
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Document total = starting total (100) + new segment length - old segment count (25)
    segment_delta = len("short") - 25
    assert document.word_count == 100 + segment_delta

    # Both the segment and the document were handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment, document])
def test_update_segment_word_count_no_change(self, mock_segment_service_dependencies):
    """An update that leaves the segment word count unchanged must not touch the document word count."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="same_length", word_count=11)
    document = SegmentUpdateTestDataFactory.create_document_mock(word_count=100)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    # The submitted content matches the stored content, so its length is unchanged
    update_args = SegmentUpdateArgs(content="same_length")
    SegmentService.update_segment(update_args, segment, document, dataset)

    # The document total is still its starting value
    assert document.word_count == 100

    # Only the segment was handed to the database session (the document was not)
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])
def test_update_segment_word_count_qa_model_with_answer(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies
):
    """For a QA document, a new answer changes both the segment and the document word counts."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(
        content="question", answer="old_answer", word_count=7
    )
    document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.QA_INDEX, word_count=100)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    # The model manager hands back the mocked embedding model
    model_manager = mock_model_manager_dependencies.return_value
    model_manager.get_model_instance.return_value = SegmentUpdateTestDataFactory.create_embedding_model_mock()

    update_args = SegmentUpdateArgs(content="question", answer="new_longer_answer")
    SegmentService.update_segment(update_args, segment, document, dataset)

    # The segment count covers question plus answer
    new_segment_words = len("question") + len("new_longer_answer")
    assert segment.word_count == new_segment_words

    # Document total = starting total (100) + new segment count - old segment count (7)
    assert document.word_count == 100 + (new_segment_words - 7)

    # Both the segment and the document were handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment, document])
# ==================== Error Handling Tests ====================
def test_update_segment_vector_service_error(
    self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies
):
    """A vector service failure leaves the segment disabled and in the error state."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content")
    document = SegmentUpdateTestDataFactory.create_document_mock()
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    # The model manager hands back the mocked embedding model
    model_manager = mock_model_manager_dependencies.return_value
    model_manager.get_model_instance.return_value = SegmentUpdateTestDataFactory.create_embedding_model_mock()

    # Make the vector update blow up
    vector_service = mock_vector_service_dependencies["vector_service"]
    vector_service.update_segment_vector.side_effect = Exception("Vector service error")

    update_args = SegmentUpdateArgs(content="new_content", keywords=["keyword"])
    SegmentService.update_segment(update_args, segment, document, dataset)

    # The segment carries the error details and a disable timestamp
    expected_attributes = {
        "enabled": False,
        "status": "error",
        "error": "Vector service error",
        "disabled_at": mock_segment_service_dependencies["current_time"].replace(tzinfo=None),
    }
    self._assert_segment_attributes_updated(segment, expected_attributes)

    # The segment, in its error state, was handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])
# ==================== Edge Cases and Integration Tests ====================
def test_update_segment_with_none_content_uses_existing(self, mock_segment_service_dependencies):
    """Passing content=None keeps the stored content while still applying the new keywords."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="existing_content")
    document = SegmentUpdateTestDataFactory.create_document_mock()
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    update_args = SegmentUpdateArgs(content=None, keywords=["new_keyword"])
    SegmentService.update_segment(update_args, segment, document, dataset)

    # Content is untouched, keywords are replaced
    assert segment.content == "existing_content"
    assert segment.keywords == ["new_keyword"]

    # The segment was handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])
def test_update_segment_with_none_answer_qa_model(self, mock_segment_service_dependencies):
    """Passing answer=None on a QA segment clears the answer and counts only the question."""
    # Redis reports no indexing job in flight for this segment
    redis_mock = mock_segment_service_dependencies["redis_client"]
    redis_mock.get.return_value = None

    segment = SegmentUpdateTestDataFactory.create_segment_mock(
        content="question", answer="old_answer", word_count=15
    )
    document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.QA_INDEX)
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()

    update_args = SegmentUpdateArgs(content="question", answer=None)
    SegmentService.update_segment(update_args, segment, document, dataset)

    # The answer is gone and the word count reflects the question alone
    assert segment.answer is None
    assert segment.word_count == len("question")

    # The segment was handed to the database session
    db_session = mock_segment_service_dependencies["db_session"]
    self._assert_database_operations(db_session, [segment])
def test_update_segment_final_query_returns_updated_segment(self, mock_segment_service_dependencies):
    """Test that the final database query returns the updated segment.

    The session's ``query(...).filter(...).first()`` chain is stubbed to yield a
    separate mock segment, and the test asserts that this mock is what
    ``SegmentService.update_segment`` returns.
    """
    segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content")
    document = SegmentUpdateTestDataFactory.create_document_mock()
    dataset = SegmentUpdateTestDataFactory.create_dataset_mock()
    # Mock Redis cache as not indexing
    mock_segment_service_dependencies["redis_client"].get.return_value = None
    # Mock final query to return updated segment. The stub is installed once;
    # assigning the same return value a second time has no additional effect.
    updated_segment = SegmentUpdateTestDataFactory.create_segment_mock(
        segment_id="segment-123", content="new_content"
    )
    mock_segment_service_dependencies[
        "db_session"
    ].query.return_value.filter.return_value.first.return_value = updated_segment
    args = SegmentUpdateArgs(content="new_content")
    result = SegmentService.update_segment(args, segment, document, dataset)
    # Verify result is the updated segment
    assert result == updated_segment
Loading…
Cancel
Save