feat: refactor: update segment, abstract common methods

pull/21682/head
neatguycoding 7 months ago
parent 6abbc91c0c
commit 1cd77a9a17

@@ -2127,174 +2127,367 @@ class SegmentService:
    @classmethod
    def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
        """
        Update a document segment with new content, keywords, and metadata.

        This method handles both simple updates (content unchanged) and complex updates (content changed)
        with proper indexing, vector updates, and error handling.

        Args:
            args: Update arguments containing new content, keywords, and settings
            segment: The document segment to update
            document: The parent document containing the segment
            dataset: The dataset containing the document

        Returns:
            DocumentSegment: The updated segment object

        Raises:
            ValueError: If the segment is currently being indexed, or if a disabled
                segment is updated without re-enabling it
        """
        # Check whether the segment is currently being indexed to prevent concurrent operations
        indexing_cache_key = "segment_{}_indexing".format(segment.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise ValueError("Segment is indexing, please try again later")
        # Handle segment enable/disable state changes
        if args.enabled is not None:
            action = args.enabled
            if segment.enabled != action:
                if not action:
                    # Disable the segment and trigger index removal
                    cls._disable_segment(segment, indexing_cache_key)
                    return segment
        # Disabled segments cannot be updated unless this update re-enables them
        if not segment.enabled:
            if args.enabled is not None:
                if not args.enabled:
                    raise ValueError("Can't update disabled segment")
            else:
                raise ValueError("Can't update disabled segment")
        try:
            # Track word count changes for the document update
            original_word_count = segment.word_count
            content = args.content or segment.content
            if segment.content == content:
                # Simple update: content unchanged, only metadata is refreshed
                cls._handle_simple_segment_update(args, segment, document, dataset, original_word_count)
            else:
                # Complex update: content changed, requires re-indexing
                cls._handle_complex_segment_update(args, segment, document, dataset, content, original_word_count)
        except Exception as e:
            # Handle update failures by marking the segment with an error state
            cls._handle_segment_update_error(segment, e)
        # Return a fresh segment object from the database
        new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
        return new_segment
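
    # Editor's note — illustrative usage, not part of this commit. The argument
    # names are inferred from the fields read above (content, answer, keywords,
    # enabled, regenerate_child_chunks); the exact SegmentUpdateArgs constructor
    # is an assumption:
    #
    #     args = SegmentUpdateArgs(content="Updated text", keywords=["pricing"], enabled=True)
    #     updated = SegmentService.update_segment(args, segment, document, dataset)
    #     assert updated.status == "completed"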
    @classmethod
    def _disable_segment(cls, segment: DocumentSegment, indexing_cache_key: str):
        """
        Disable a segment and trigger index removal.

        Args:
            segment: The segment to disable
            indexing_cache_key: Redis cache key for the indexing lock
        """
        segment.enabled = False
        segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        segment.disabled_by = current_user.id
        db.session.add(segment)
        db.session.commit()
        # Set the cache key to prevent indexing the same segment multiple times
        redis_client.setex(indexing_cache_key, 600, 1)
        disable_segment_from_index_task.delay(segment.id)
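
    # Editor's note — sketch of the assumed locking convention: the key written
    # here is the same one checked at the top of update_segment, so a disabled
    # segment cannot be updated again until the async removal finishes or the
    # 600-second TTL expires:
    #
    #     redis_client.setex(indexing_cache_key, 600, 1)   # acquire advisory lock
    #     redis_client.get(indexing_cache_key)             # gate check in update_segment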
    @classmethod
    def _handle_simple_segment_update(
        cls,
        args: SegmentUpdateArgs,
        segment: DocumentSegment,
        document: Document,
        dataset: Dataset,
        original_word_count: int,
    ):
        """
        Handle a segment update when the content remains unchanged.

        Only updates metadata such as keywords, answer, and word count.

        Args:
            args: Update arguments
            segment: The segment to update
            document: Parent document
            dataset: Parent dataset
            original_word_count: Original word count before the update
        """
        # Update word count for the main content
        segment.word_count = len(segment.content)
        # Handle QA-model-specific updates
        if document.doc_form == "qa_model":
            segment.answer = args.answer
            segment.word_count += len(args.answer) if args.answer else 0
        # Calculate the word count change for the document update
        word_count_change = segment.word_count - original_word_count
        # Update keywords if provided and changed
        keyword_changed = False
        if args.keywords:
            if Counter(segment.keywords) != Counter(args.keywords):
                segment.keywords = args.keywords
                keyword_changed = True
        # Reset segment state to enabled
        segment.enabled = True
        segment.disabled_at = None
        segment.disabled_by = None
        # Persist changes to the database
        db.session.add(segment)
        db.session.commit()
        # Update the document word count if it changed
        if word_count_change != 0:
            document.word_count = max(0, document.word_count + word_count_change)
            db.session.add(document)
        # Handle vector index updates based on document type
        cls._handle_vector_index_updates(args, segment, document, dataset, keyword_changed, False)
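
    # Editor's note — worked example of the bookkeeping above (word_count is a
    # character count via len()): with 120 characters of content and a
    # 30-character answer on a qa_model segment whose original count was 140:
    #
    #     segment.word_count = 120 + 30                        # -> 150
    #     word_count_change = 150 - 140                        # -> +10
    #     document.word_count = max(0, document.word_count + 10)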
    @classmethod
    def _handle_complex_segment_update(
        cls,
        args: SegmentUpdateArgs,
        segment: DocumentSegment,
        document: Document,
        dataset: Dataset,
        content: str,
        original_word_count: int,
    ):
        """
        Handle a segment update when the content has changed.

        Requires re-indexing and vector updates.

        Args:
            args: Update arguments
            segment: The segment to update
            document: Parent document
            dataset: Parent dataset
            content: New content for the segment
            original_word_count: Original word count before the update
        """
        # Generate a new content hash for change detection
        segment_hash = helper.generate_text_hash(content)
        # Calculate tokens for high-quality indexing
        tokens = cls._calculate_segment_tokens(args, content, document, dataset)
        # Update the segment with new content and metadata
        cls._update_segment_metadata(document, segment, content, segment_hash, tokens, args)
        # Calculate the word count change for the document update
        word_count_change = segment.word_count - original_word_count
        # Update the document word count if it changed
        if word_count_change != 0:
            document.word_count = max(0, document.word_count + word_count_change)
            db.session.add(document)
        # Persist changes to the database
        db.session.add(segment)
        db.session.commit()
        # Handle vector index updates based on document type
        cls._handle_vector_index_updates(args, segment, document, dataset, True, True)
    @classmethod
    def _calculate_segment_tokens(
        cls, args: SegmentUpdateArgs, content: str, document: Document, dataset: Dataset
    ) -> int:
        """
        Calculate the token count for segment content based on the indexing technique.

        Args:
            args: Update arguments
            content: Segment content
            document: Parent document
            dataset: Parent dataset

        Returns:
            int: Token count for the content
        """
        tokens = 0
        if dataset.indexing_technique == "high_quality":
            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=current_user.current_tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )
            # Calculate embedding tokens for QA-model or regular content
            if document.doc_form == "qa_model":
                answer = args.answer or ""
                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + answer])[0]
            else:
                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
        return tokens
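
    # Editor's note — example of the token accounting above (tokenization is
    # delegated to the embedding model; the sample strings are assumptions):
    # for a qa_model document the question and answer are embedded as one text,
    # so tokens are counted on the concatenation:
    #
    #     texts = ["What is X?" + "X is ..."]  # content + answer
    #     tokens = embedding_model.get_text_embedding_num_tokens(texts=texts)[0]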
    @classmethod
    def _update_segment_metadata(
        cls,
        document: Document,
        segment: DocumentSegment,
        content: str,
        segment_hash: str,
        tokens: int,
        args: SegmentUpdateArgs,
    ):
        """
        Update segment metadata with new content and timestamps.

        Args:
            document: Parent document (used for QA-model-specific handling)
            segment: The segment to update
            content: New content
            segment_hash: Content hash for change detection
            tokens: Token count for the content
            args: Update arguments
        """
        segment.content = content
        segment.index_node_hash = segment_hash
        segment.word_count = len(content)
        segment.tokens = tokens
        segment.status = "completed"
        # Update timestamps
        current_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        segment.indexing_at = current_time
        segment.completed_at = current_time
        segment.updated_by = current_user.id
        segment.updated_at = current_time
        # Reset segment state to enabled
        segment.enabled = True
        segment.disabled_at = None
        segment.disabled_by = None
        # Handle QA-model-specific metadata
        if document.doc_form == "qa_model":
            segment.answer = args.answer
            segment.word_count += len(args.answer) if args.answer else 0
    @classmethod
    def _handle_vector_index_updates(
        cls,
        args: SegmentUpdateArgs,
        segment: DocumentSegment,
        document: Document,
        dataset: Dataset,
        keyword_changed: bool,
        content_changed: bool,
    ):
        """
        Handle vector index updates based on document type and update conditions.

        Args:
            args: Update arguments
            segment: The segment to update
            document: Parent document
            dataset: Parent dataset
            keyword_changed: Whether keywords were changed in this update
            content_changed: Whether content was changed in this update
        """
        if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
            # Regenerate child chunks for parent-child indexing
            cls._regenerate_child_chunks(segment, document, dataset)
        elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
            # Update the segment vector for paragraph/QA indexing
            if content_changed or args.enabled or keyword_changed:
                VectorService.update_segment_vector(args.keywords, segment, dataset)
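
    # Editor's note — decision summary for the branches above:
    #
    #     PARENT_CHILD_INDEX and regenerate_child_chunks -> _regenerate_child_chunks
    #     PARAGRAPH_INDEX or QA_INDEX                    -> update_segment_vector, but only
    #         when content changed, keywords changed, or the segment is being enabled
    #     any other doc_form                             -> no index update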
    @classmethod
    def _regenerate_child_chunks(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
        """
        Regenerate child chunks for parent-child indexing.

        Args:
            segment: The segment to regenerate chunks for
            document: Parent document
            dataset: Parent dataset

        Raises:
            ValueError: If the indexing technique is not high quality or no processing rule is found
        """
        if dataset.indexing_technique != "high_quality":
            raise ValueError("The knowledge base index technique is not high quality!")
        # Get the embedding model instance
        embedding_model_instance = cls._get_embedding_model_instance(dataset)
        # Get the processing rule for chunk generation
        processing_rule = (
            db.session.query(DatasetProcessRule)
            .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
            .first()
        )
        if not processing_rule:
            raise ValueError("No processing rule found.")
        # Generate new child chunks
        VectorService.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, True)
    @classmethod
    def _get_embedding_model_instance(cls, dataset: Dataset):
        """
        Get the embedding model instance for the dataset.

        Args:
            dataset: The dataset to get the model for

        Returns:
            Model instance for text embedding
        """
        model_manager = ModelManager()
        if dataset.embedding_model_provider:
            return model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )
        else:
            return model_manager.get_default_model_instance(
                tenant_id=dataset.tenant_id,
                model_type=ModelType.TEXT_EMBEDDING,
            )
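
    # Editor's note — the fallback above in brief (the provider value is a
    # hypothetical example): a dataset pinned to a provider resolves that exact
    # embedding model, otherwise the tenant's default is used:
    #
    #     dataset.embedding_model_provider = "openai"  # -> get_model_instance(...)
    #     dataset.embedding_model_provider = None      # -> get_default_model_instance(...)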
    @classmethod
    def _handle_segment_update_error(cls, segment: DocumentSegment, error: Exception):
        """
        Handle errors during a segment update by marking the segment with an error state.

        Args:
            segment: The segment that encountered an error
            error: The exception that occurred
        """
        logging.exception("update segment index failed")
        segment.enabled = False
        segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        segment.status = "error"
        segment.error = str(error)
        db.session.commit()
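
    # Editor's note — after a failed update the segment remains queryable in an
    # inspectable error state:
    #
    #     segment.status   # "error"
    #     segment.error    # str(original exception)
    #     segment.enabled  # False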
    @classmethod
    def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
        indexing_cache_key = "segment_{}_delete_indexing".format(segment.id)
