refactor: save_document_with_dataset_id for readability

pull/21528/head
neatguycoding 11 months ago
parent bad76a6f72
commit 2036d951fc

@ -1053,43 +1053,156 @@ class DocumentService:
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
# check document limit
"""
Save documents to a dataset with comprehensive validation and processing.
This method handles document creation and updates for various data sources including
file uploads, Notion imports, and website crawling. It performs billing checks,
dataset configuration, and triggers indexing tasks.
Args:
dataset: The target dataset for document storage
knowledge_config: Configuration containing document metadata and processing rules
account: User account performing the operation
dataset_process_rule: Optional pre-existing process rule for the dataset
created_from: Source identifier for document creation (default: "web")
Returns:
tuple: (list of created/updated documents, batch identifier)
Raises:
ValueError: When billing limits are exceeded or configuration is invalid
FileNotExistsError: When referenced files are not found
"""
# Validate billing and upload limits for new documents
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
if not knowledge_config.original_document_id:
if features.billing.enabled and not knowledge_config.original_document_id:
document_count = DocumentService._calculate_document_count(knowledge_config)
if document_count > 0:
DocumentService._validate_upload_limits(document_count, features)
# Initialize dataset configuration if not already set
DocumentService._initialize_dataset_configuration(dataset, knowledge_config)
documents = []
# Handle document update scenario
if knowledge_config.original_document_id:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
documents.append(document)
batch = document.batch
else:
# Handle new document creation scenario
batch = DocumentService._generate_batch_identifier()
# Create or retrieve process rule for document processing
dataset_process_rule = DocumentService._prepare_process_rule(
dataset, knowledge_config, account, dataset_process_rule
)
if not dataset_process_rule:
# keep the original process rule if no valid rule found, but it seems not completed
return
# Process documents with distributed lock to prevent race conditions
documents, batch = DocumentService._process_new_documents(
dataset, knowledge_config, account, dataset_process_rule, created_from, batch
)
return documents, batch
@staticmethod
def _calculate_document_count(knowledge_config: KnowledgeConfig) -> int:
    """
    Calculate the total number of documents to be processed based on data source type.

    Args:
        knowledge_config: Configuration containing data source information

    Returns:
        int: Total count of documents to be processed. 0 when no data source is
            configured or the data source type is not one of the handled types.
    """
    if not knowledge_config.data_source:
        return 0

    data_source_info = knowledge_config.data_source.info_list
    data_source_type = data_source_info.data_source_type

    if data_source_type == "upload_file":
        return len(data_source_info.file_info_list.file_ids)  # type: ignore

    if data_source_type == "notion_import":
        # Every entry of notion_info_list carries its own list of pages; count them all.
        return sum(len(notion_info.pages) for notion_info in data_source_info.notion_info_list)  # type: ignore

    if data_source_type == "website_crawl":
        return len(data_source_info.website_info_list.urls)  # type: ignore

    return 0
@staticmethod
def _validate_upload_limits(document_count: int, features: FeatureModel) -> None:
    """
    Validate that the document upload operation complies with billing and subscription limits.

    Args:
        document_count: Number of documents to be uploaded
        features: Feature model containing billing and subscription information

    Raises:
        ValueError: When the plan is "sandbox" and more than one document is uploaded,
            or when document_count exceeds the configured batch upload limit
    """
    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

    # Sandbox plan: more than one document per request is rejected.
    if features.billing.subscription.plan == "sandbox" and document_count > 1:
        raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")

    # Global per-request cap taken from configuration.
    if document_count > batch_upload_limit:
        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

    # Remaining quota check is delegated to the shared helper.
    DocumentService.check_documents_upload_quota(document_count, features)
@staticmethod
def _initialize_dataset_configuration(dataset: Dataset, knowledge_config: KnowledgeConfig) -> None:
    """
    Initialize dataset configuration settings if not already configured.

    Fills in the dataset's data source type and indexing technique from
    knowledge_config when the dataset does not have them yet.

    Args:
        dataset: Dataset to be configured
        knowledge_config: Configuration containing dataset settings

    Raises:
        ValueError: When indexing technique is invalid
    """
    # Set data source type if not already configured
    if not dataset.data_source_type:
        dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

    # Configure indexing technique and related settings
    # NOTE(review): the nesting below is reconstructed on the assumption that an
    # already-configured dataset must not be touched - confirm against the repository.
    if not dataset.indexing_technique:
        if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
            raise ValueError("Indexing technique is invalid")

        dataset.indexing_technique = knowledge_config.indexing_technique

        # Configure high-quality indexing settings
        if knowledge_config.indexing_technique == "high_quality":
            DocumentService._configure_high_quality_indexing(dataset, knowledge_config)
@staticmethod
def _configure_high_quality_indexing(dataset: Dataset, knowledge_config: KnowledgeConfig) -> None:
"""
Configure embedding model and retrieval settings for high-quality indexing.
Args:
dataset: Dataset to be configured
knowledge_config: Configuration containing model settings
"""
model_manager = ModelManager()
# Set embedding model configuration
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
@ -1099,12 +1212,17 @@ class DocumentService:
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
# Configure collection binding
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
# Configure retrieval model if not set
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -1120,17 +1238,45 @@ class DocumentService:
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
documents.append(document)
batch = document.batch
else:
batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000))
# save process rule
if not dataset_process_rule:
@staticmethod
def _generate_batch_identifier() -> str:
"""
Generate a unique batch identifier for document grouping.
Returns:
str: Unique batch identifier combining timestamp and random number
"""
return time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000))
@staticmethod
def _prepare_process_rule(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule],
) -> Optional[DatasetProcessRule]:
"""
Prepare or create dataset process rule for document processing.
Args:
dataset: Target dataset
knowledge_config: Configuration containing process rules
account: User account
dataset_process_rule: Optional existing process rule
Returns:
DatasetProcessRule: Prepared process rule for document processing
Raises:
ValueError: When no valid process rule can be found
"""
if dataset_process_rule:
return dataset_process_rule
process_rule = knowledge_config.process_rule
if process_rule:
if not process_rule:
return None
if process_rule.mode in ("custom", "hierarchical"):
if process_rule.rules:
dataset_process_rule = DatasetProcessRule(
@ -1151,37 +1297,165 @@ class DocumentService:
created_by=account.id,
)
else:
logging.warning(
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
)
return
logging.warning(f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule")
return None
db.session.add(dataset_process_rule)
db.session.commit()
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
return dataset_process_rule
@staticmethod
def _process_new_documents(
    dataset: Dataset,
    knowledge_config: KnowledgeConfig,
    account: Account,
    dataset_process_rule: Optional[DatasetProcessRule],
    created_from: str,
    batch: str,
) -> tuple[list, str]:
    """
    Process new documents while holding a per-dataset lock.

    Dispatches to the handler matching the configured data source type,
    commits the session and then queues the indexing tasks.

    Args:
        dataset: Target dataset
        knowledge_config: Document configuration
        account: User account
        dataset_process_rule: Process rule for document processing
        created_from: Source identifier
        batch: Batch identifier

    Returns:
        tuple: (list of created documents, batch identifier). The list is empty
            when the data source type is not one of the handled types.
    """
    lock_name = f"add_document_lock_dataset_id_{dataset.id}"
    # NOTE(review): commit and task dispatch are kept inside the lock on the
    # assumption that the position read must see earlier commits - confirm.
    with redis_client.lock(lock_name, timeout=600):
        position = DocumentService.get_documents_position(dataset.id)
        document_ids: list[str] = []
        duplicate_document_ids: list[str] = []
        documents: list[Document] = []
        data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

        # Process documents based on data source type
        if data_source_type == "upload_file":
            documents, document_ids, duplicate_document_ids = DocumentService._process_upload_file_documents(
                dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position
            )
        elif data_source_type == "notion_import":
            documents, document_ids = DocumentService._process_notion_documents(
                dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position
            )
        elif data_source_type == "website_crawl":
            documents, document_ids = DocumentService._process_website_documents(
                dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position
            )

        db.session.commit()

        # Trigger asynchronous indexing tasks
        if document_ids:
            document_indexing_task.delay(dataset.id, document_ids)
        if duplicate_document_ids:
            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

    return documents, batch
@staticmethod
def _process_upload_file_documents(
    dataset: Dataset,
    knowledge_config: KnowledgeConfig,
    account: Account,
    dataset_process_rule: Optional[DatasetProcessRule],
    created_from: str,
    batch: str,
    position: int,
) -> tuple[list[Document], list[str], list[str]]:
    """
    Process uploaded file documents with duplicate checking and validation.

    For every file id in the configuration, either re-uses an existing document
    with the same name (when knowledge_config.duplicate is set) or builds a new one.

    Args:
        dataset: Target dataset
        knowledge_config: Document configuration
        account: User account
        dataset_process_rule: Process rule
        created_from: Source identifier
        batch: Batch identifier
        position: Starting position for document ordering

    Returns:
        tuple: (list of documents, list of document IDs, list of duplicate document IDs)

    Raises:
        FileNotExistsError: When referenced files are not found
    """
    documents: list[Document] = []
    document_ids: list[str] = []
    duplicate_document_ids: list[str] = []
    current_position = position

    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
    for file_id in upload_file_list:
        # Validate file existence; the lookup is scoped to the dataset's tenant.
        file = (
            db.session.query(UploadFile)
            .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
            .first()
        )
        if not file:
            raise FileNotExistsError()

        file_name = file.name
        data_source_info = {"upload_file_id": file_id}

        # Handle duplicate document processing
        if knowledge_config.duplicate:
            document = DocumentService._find_duplicate_document(dataset, file_name)
            if document:
                document = DocumentService._update_duplicate_document(
                    document, dataset_process_rule, knowledge_config, created_from, data_source_info, batch
                )
                documents.append(document)
                duplicate_document_ids.append(document.id)
                # Updated duplicates do not consume a new position.
                continue

        # Create new document
        document = DocumentService.build_document(
            dataset,
            dataset_process_rule.id,  # type: ignore
            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
            knowledge_config.doc_form,
            knowledge_config.doc_language,
            data_source_info,
            created_from,
            current_position,
            account,
            file_name,
            batch,
        )
        db.session.add(document)
        # Flush before reading document.id for the indexing task list.
        db.session.flush()
        document_ids.append(document.id)
        documents.append(document)
        current_position += 1

    return documents, document_ids, duplicate_document_ids
@staticmethod
def _find_duplicate_document(dataset: Dataset, file_name: str) -> Optional[Document]:
"""
Find existing document with the same name in the dataset.
Args:
dataset: Target dataset
file_name: Name of the file to check for duplicates
Returns:
Document: Existing document if found, None otherwise
"""
return (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
@ -1192,7 +1466,30 @@ class DocumentService:
)
.first()
)
if document:
@staticmethod
def _update_duplicate_document(
document: Document,
dataset_process_rule: Optional[DatasetProcessRule],
knowledge_config: KnowledgeConfig,
created_from: str,
data_source_info: dict,
batch: str,
) -> Document:
"""
Update existing document with new configuration and mark for re-indexing.
Args:
document: Document to be updated
dataset_process_rule: Process rule to apply
knowledge_config: New configuration
created_from: Source identifier
data_source_info: Data source information
batch: Batch identifier
Returns:
Document: Updated document
"""
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.created_from = created_from
@ -1202,9 +1499,66 @@ class DocumentService:
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
return document
@staticmethod
def _process_notion_documents(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule],
created_from: str,
batch: str,
position: int,
) -> tuple[list[Document], list[str]]:
"""
Process Notion import documents with workspace validation and page deduplication.
Args:
dataset: Target dataset
knowledge_config: Document configuration
account: User account
dataset_process_rule: Process rule
created_from: Source identifier
batch: Batch identifier
position: Starting position for document ordering
Returns:
tuple: (list of documents, list of document IDs)
Raises:
ValueError: When no notion info list is found or data source binding is missing
"""
documents: list[Document] = []
document_ids: list[str] = []
current_position = position
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
# Get existing Notion documents for deduplication
exist_page_ids, exist_document = DocumentService._get_existing_notion_documents(dataset)
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
# Validate data source binding
data_source_binding = DocumentService._validate_notion_binding(workspace_id)
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
# Create new Notion document
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
"type": page.type,
}
# Truncate page name to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
@ -1213,23 +1567,40 @@ class DocumentService:
knowledge_config.doc_language,
data_source_info,
created_from,
position,
current_position,
account,
file_name,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents_from_db = (
current_position += 1
else:
exist_document.pop(page.page_id)
# Clean up unselected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
return documents, document_ids
@staticmethod
def _get_existing_notion_documents(dataset: Dataset) -> tuple[list[str], dict[str, str]]:
"""
Retrieve existing Notion documents for deduplication purposes.
Args:
dataset: Target dataset
Returns:
tuple: (list of existing page IDs, dict mapping page IDs to document IDs)
"""
exist_page_ids: list[str] = []
exist_document: dict[str, str] = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
@ -1239,13 +1610,29 @@ class DocumentService:
)
.all()
)
if documents_from_db:
for document in documents_from_db:
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
return exist_page_ids, exist_document
@staticmethod
def _validate_notion_binding(workspace_id: str) -> DataSourceOauthBinding:
"""
Validate Notion data source binding for the specified workspace.
Args:
workspace_id: Notion workspace identifier
Returns:
DataSourceOauthBinding: Valid binding for the workspace
Raises:
ValueError: When data source binding is not found
"""
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
@ -1258,46 +1645,50 @@ class DocumentService:
)
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
return data_source_binding
@staticmethod
def _process_website_documents(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule],
created_from: str,
batch: str,
position: int,
) -> tuple[list[Document], list[str]]:
"""
Process website crawl documents with URL validation and naming.
Args:
dataset: Target dataset
knowledge_config: Document configuration
account: User account
dataset_process_rule: Process rule
created_from: Source identifier
batch: Batch identifier
position: Starting position for document ordering
Returns:
tuple: (list of documents, list of document IDs)
Raises:
ValueError: When no website info list is found
"""
documents: list[Document] = []
document_ids: list[str] = []
current_position = position
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
@ -1306,10 +1697,10 @@ class DocumentService:
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
# Truncate URL for document naming if too long
document_name = url[:200] + "..." if len(url) > 255 else url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
@ -1318,7 +1709,7 @@ class DocumentService:
knowledge_config.doc_language,
data_source_info,
created_from,
position,
current_position,
account,
document_name,
batch,
@ -1327,16 +1718,9 @@ class DocumentService:
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
current_position += 1
return documents, batch
return documents, document_ids
@staticmethod
def check_documents_upload_quota(count: int, features: FeatureModel):

Loading…
Cancel
Save