feat: refactor: save_document_with_dataset_id for readability

pull/21528/head
neatguycoding 11 months ago
parent bad76a6f72
commit 2036d951fc

@ -1053,43 +1053,156 @@ class DocumentService:
dataset_process_rule: Optional[DatasetProcessRule] = None, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web", created_from: str = "web",
): ):
# check document limit """
Save documents to a dataset with comprehensive validation and processing.
This method handles document creation and updates for various data sources including
file uploads, Notion imports, and website crawling. It performs billing checks,
dataset configuration, and triggers indexing tasks.
Args:
dataset: The target dataset for document storage
knowledge_config: Configuration containing document metadata and processing rules
account: User account performing the operation
dataset_process_rule: Optional pre-existing process rule for the dataset
created_from: Source identifier for document creation (default: "web")
Returns:
tuple: (list of created/updated documents, batch identifier)
Raises:
ValueError: When billing limits are exceeded or configuration is invalid
FileNotExistsError: When referenced files are not found
"""
# Validate billing and upload limits for new documents
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled and not knowledge_config.original_document_id:
if not knowledge_config.original_document_id: document_count = DocumentService._calculate_document_count(knowledge_config)
if document_count > 0:
DocumentService._validate_upload_limits(document_count, features)
# Initialize dataset configuration if not already set
DocumentService._initialize_dataset_configuration(dataset, knowledge_config)
documents = []
# Handle document update scenario
if knowledge_config.original_document_id:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
documents.append(document)
batch = document.batch
else:
# Handle new document creation scenario
batch = DocumentService._generate_batch_identifier()
# Create or retrieve process rule for document processing
dataset_process_rule = DocumentService._prepare_process_rule(
dataset, knowledge_config, account, dataset_process_rule
)
if not dataset_process_rule:
# keep the original process rule if no valid rule found, but it seems not completed
return
# Process documents with distributed lock to prevent race conditions
documents, batch = DocumentService._process_new_documents(
dataset, knowledge_config, account, dataset_process_rule, created_from, batch
)
return documents, batch
@staticmethod
def _calculate_document_count(knowledge_config: KnowledgeConfig) -> int:
"""
Calculate the total number of documents to be processed based on data source type.
Args:
knowledge_config: Configuration containing data source information
Returns:
int: Total count of documents to be processed
"""
count = 0 count = 0
if knowledge_config.data_source: if knowledge_config.data_source:
if knowledge_config.data_source.info_list.data_source_type == "upload_file": data_source_info = knowledge_config.data_source.info_list
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
if data_source_info.data_source_type == "upload_file":
upload_file_list = data_source_info.file_info_list.file_ids # type: ignore
count = len(upload_file_list) count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": elif data_source_info.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list notion_info_list = data_source_info.notion_info_list
for notion_info in notion_info_list: # type: ignore for notion_info in notion_info_list: # type: ignore
count = count + len(notion_info.pages) count += len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": elif data_source_info.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list website_info = data_source_info.website_info_list
count = len(website_info.urls) # type: ignore count = len(website_info.urls) # type: ignore
return count
@staticmethod
def _validate_upload_limits(document_count: int, features: FeatureModel) -> None:
"""
Validate that the document upload operation complies with billing and subscription limits.
Args:
document_count: Number of documents to be uploaded
features: Feature model containing billing and subscription information
Raises:
ValueError: When upload limits are exceeded
"""
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1: # Check sandbox plan restrictions
if features.billing.subscription.plan == "sandbox" and document_count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
# Check batch upload limit
if document_count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
DocumentService.check_documents_upload_quota(count, features) # Check document quota
DocumentService.check_documents_upload_quota(document_count, features)
@staticmethod
def _initialize_dataset_configuration(dataset: Dataset, knowledge_config: KnowledgeConfig) -> None:
"""
Initialize dataset configuration settings if not already configured.
# if dataset is empty, update dataset data_source_type Args:
dataset: Dataset to be configured
knowledge_config: Configuration containing dataset settings
Raises:
ValueError: When indexing technique is invalid
"""
# Set data source type if not already configured
if not dataset.data_source_type: if not dataset.data_source_type:
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
# Configure indexing technique and related settings
if not dataset.indexing_technique: if not dataset.indexing_technique:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid") raise ValueError("Indexing technique is invalid")
dataset.indexing_technique = knowledge_config.indexing_technique dataset.indexing_technique = knowledge_config.indexing_technique
# Configure high-quality indexing settings
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
DocumentService._configure_high_quality_indexing(dataset, knowledge_config)
@staticmethod
def _configure_high_quality_indexing(dataset: Dataset, knowledge_config: KnowledgeConfig) -> None:
"""
Configure embedding model and retrieval settings for high-quality indexing.
Args:
dataset: Dataset to be configured
knowledge_config: Configuration containing model settings
"""
model_manager = ModelManager() model_manager = ModelManager()
# Set embedding model configuration
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider dataset_embedding_model_provider = knowledge_config.embedding_model_provider
@ -1099,12 +1212,17 @@ class DocumentService:
) )
dataset_embedding_model = embedding_model.model dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider dataset.embedding_model_provider = dataset_embedding_model_provider
# Configure collection binding
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_embedding_model_provider, dataset_embedding_model dataset_embedding_model_provider, dataset_embedding_model
) )
dataset.collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding.id
# Configure retrieval model if not set
if not dataset.retrieval_model: if not dataset.retrieval_model:
default_retrieval_model = { default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -1120,17 +1238,45 @@ class DocumentService:
else default_retrieval_model else default_retrieval_model
) # type: ignore ) # type: ignore
documents = [] @staticmethod
if knowledge_config.original_document_id: def _generate_batch_identifier() -> str:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) """
documents.append(document) Generate a unique batch identifier for document grouping.
batch = document.batch
else: Returns:
batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000)) str: Unique batch identifier combining timestamp and random number
# save process rule """
if not dataset_process_rule: return time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000))
@staticmethod
def _prepare_process_rule(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule],
) -> Optional[DatasetProcessRule]:
"""
Prepare or create dataset process rule for document processing.
Args:
dataset: Target dataset
knowledge_config: Configuration containing process rules
account: User account
dataset_process_rule: Optional existing process rule
Returns:
DatasetProcessRule: Prepared process rule for document processing
Raises:
ValueError: When no valid process rule can be found
"""
if dataset_process_rule:
return dataset_process_rule
process_rule = knowledge_config.process_rule process_rule = knowledge_config.process_rule
if process_rule: if not process_rule:
return None
if process_rule.mode in ("custom", "hierarchical"): if process_rule.mode in ("custom", "hierarchical"):
if process_rule.rules: if process_rule.rules:
dataset_process_rule = DatasetProcessRule( dataset_process_rule = DatasetProcessRule(
@ -1151,37 +1297,165 @@ class DocumentService:
created_by=account.id, created_by=account.id,
) )
else: else:
logging.warning( logging.warning(f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule")
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" return None
)
return
db.session.add(dataset_process_rule) db.session.add(dataset_process_rule)
db.session.commit() db.session.commit()
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) return dataset_process_rule
@staticmethod
def _process_new_documents(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule],
created_from: str,
batch: str,
) -> tuple[list, str]:
"""
Process new documents with distributed locking to prevent race conditions.
Args:
dataset: Target dataset
knowledge_config: Document configuration
account: User account
dataset_process_rule: Process rule for document processing
created_from: Source identifier
batch: Batch identifier
Returns:
tuple: (list of created documents, batch identifier)
"""
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
with redis_client.lock(lock_name, timeout=600): with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id) position = DocumentService.get_documents_position(dataset.id)
document_ids = [] document_ids: list[str] = []
duplicate_document_ids = [] duplicate_document_ids: list[str] = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore documents: list[Document] = []
data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
# Process documents based on data source type
if data_source_type == "upload_file":
documents, document_ids, duplicate_document_ids = DocumentService._process_upload_file_documents(
dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position
)
elif data_source_type == "notion_import":
documents, document_ids = DocumentService._process_notion_documents(
dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position
)
elif data_source_type == "website_crawl":
documents, document_ids = DocumentService._process_website_documents(
dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position
)
db.session.commit()
# Trigger asynchronous indexing tasks
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch
@staticmethod
def _process_upload_file_documents(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule],
created_from: str,
batch: str,
position: int,
) -> tuple[list[Document], list[str], list[str]]:
"""
Process uploaded file documents with duplicate checking and validation.
Args:
dataset: Target dataset
knowledge_config: Document configuration
account: User account
dataset_process_rule: Process rule
created_from: Source identifier
batch: Batch identifier
position: Starting position for document ordering
Returns:
tuple: (list of documents, list of document IDs, list of duplicate document IDs)
Raises:
FileNotExistsError: When referenced files are not found
"""
documents: list[Document] = []
document_ids: list[str] = []
duplicate_document_ids: list[str] = []
current_position = position
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
for file_id in upload_file_list: for file_id in upload_file_list:
# Validate file existence
file = ( file = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first() .first()
) )
# raise error if file not found
if not file: if not file:
raise FileNotExistsError() raise FileNotExistsError()
file_name = file.name file_name = file.name
data_source_info = { data_source_info = {"upload_file_id": file_id}
"upload_file_id": file_id,
} # Handle duplicate document processing
# check duplicate
if knowledge_config.duplicate: if knowledge_config.duplicate:
document = ( document = DocumentService._find_duplicate_document(dataset, file_name)
if document:
document = DocumentService._update_duplicate_document(
document, dataset_process_rule, knowledge_config, created_from, data_source_info, batch
)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
# Create new document
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
current_position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
current_position += 1
return documents, document_ids, duplicate_document_ids
@staticmethod
def _find_duplicate_document(dataset: Dataset, file_name: str) -> Optional[Document]:
"""
Find existing document with the same name in the dataset.
Args:
dataset: Target dataset
file_name: Name of the file to check for duplicates
Returns:
Document: Existing document if found, None otherwise
"""
return (
db.session.query(Document) db.session.query(Document)
.filter_by( .filter_by(
dataset_id=dataset.id, dataset_id=dataset.id,
@ -1192,7 +1466,30 @@ class DocumentService:
) )
.first() .first()
) )
if document:
@staticmethod
def _update_duplicate_document(
document: Document,
dataset_process_rule: Optional[DatasetProcessRule],
knowledge_config: KnowledgeConfig,
created_from: str,
data_source_info: dict,
batch: str,
) -> Document:
"""
Update existing document with new configuration and mark for re-indexing.
Args:
document: Document to be updated
dataset_process_rule: Process rule to apply
knowledge_config: New configuration
created_from: Source identifier
data_source_info: Data source information
batch: Batch identifier
Returns:
Document: Updated document
"""
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.created_from = created_from document.created_from = created_from
@ -1202,9 +1499,66 @@ class DocumentService:
document.batch = batch document.batch = batch
document.indexing_status = "waiting" document.indexing_status = "waiting"
db.session.add(document) db.session.add(document)
documents.append(document) return document
duplicate_document_ids.append(document.id)
continue @staticmethod
def _process_notion_documents(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule],
created_from: str,
batch: str,
position: int,
) -> tuple[list[Document], list[str]]:
"""
Process Notion import documents with workspace validation and page deduplication.
Args:
dataset: Target dataset
knowledge_config: Document configuration
account: User account
dataset_process_rule: Process rule
created_from: Source identifier
batch: Batch identifier
position: Starting position for document ordering
Returns:
tuple: (list of documents, list of document IDs)
Raises:
ValueError: When no notion info list is found or data source binding is missing
"""
documents: list[Document] = []
document_ids: list[str] = []
current_position = position
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
# Get existing Notion documents for deduplication
exist_page_ids, exist_document = DocumentService._get_existing_notion_documents(dataset)
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
# Validate data source binding
data_source_binding = DocumentService._validate_notion_binding(workspace_id)
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
# Create new Notion document
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
"type": page.type,
}
# Truncate page name to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document( document = DocumentService.build_document(
dataset, dataset,
dataset_process_rule.id, # type: ignore dataset_process_rule.id, # type: ignore
@ -1213,23 +1567,40 @@ class DocumentService:
knowledge_config.doc_language, knowledge_config.doc_language,
data_source_info, data_source_info,
created_from, created_from,
position, current_position,
account, account,
file_name, truncated_page_name,
batch, batch,
) )
db.session.add(document) db.session.add(document)
db.session.flush() db.session.flush()
document_ids.append(document.id) document_ids.append(document.id)
documents.append(document) documents.append(document)
position += 1 current_position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore else:
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore exist_document.pop(page.page_id)
if not notion_info_list:
raise ValueError("No notion info list found.") # Clean up unselected documents
exist_page_ids = [] if len(exist_document) > 0:
exist_document = {} clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
documents_from_db = (
return documents, document_ids
@staticmethod
def _get_existing_notion_documents(dataset: Dataset) -> tuple[list[str], dict[str, str]]:
"""
Retrieve existing Notion documents for deduplication purposes.
Args:
dataset: Target dataset
Returns:
tuple: (list of existing page IDs, dict mapping page IDs to document IDs)
"""
exist_page_ids: list[str] = []
exist_document: dict[str, str] = {}
documents = (
db.session.query(Document) db.session.query(Document)
.filter_by( .filter_by(
dataset_id=dataset.id, dataset_id=dataset.id,
@ -1239,13 +1610,29 @@ class DocumentService:
) )
.all() .all()
) )
if documents_from_db:
for document in documents_from_db: if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info) data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"]) exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id return exist_page_ids, exist_document
@staticmethod
def _validate_notion_binding(workspace_id: str) -> DataSourceOauthBinding:
"""
Validate Notion data source binding for the specified workspace.
Args:
workspace_id: Notion workspace identifier
Returns:
DataSourceOauthBinding: Valid binding for the workspace
Raises:
ValueError: When data source binding is not found
"""
data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .filter(
@ -1258,46 +1645,50 @@ class DocumentService:
) )
.first() .first()
) )
if not data_source_binding: if not data_source_binding:
raise ValueError("Data source binding not found.") raise ValueError("Data source binding not found.")
for page in notion_info.pages:
if page.page_id not in exist_page_ids: return data_source_binding
data_source_info = {
"notion_workspace_id": workspace_id, @staticmethod
"notion_page_id": page.page_id, def _process_website_documents(
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, dataset: Dataset,
"type": page.type, knowledge_config: KnowledgeConfig,
} account: Account,
# Truncate page name to 255 characters to prevent DB field length errors dataset_process_rule: Optional[DatasetProcessRule],
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" created_from: str,
document = DocumentService.build_document( batch: str,
dataset, position: int,
dataset_process_rule.id, # type: ignore ) -> tuple[list[Document], list[str]]:
knowledge_config.data_source.info_list.data_source_type, # type: ignore """
knowledge_config.doc_form, Process website crawl documents with URL validation and naming.
knowledge_config.doc_language,
data_source_info, Args:
created_from, dataset: Target dataset
position, knowledge_config: Document configuration
account, account: User account
truncated_page_name, dataset_process_rule: Process rule
batch, created_from: Source identifier
) batch: Batch identifier
db.session.add(document) position: Starting position for document ordering
db.session.flush()
document_ids.append(document.id) Returns:
documents.append(document) tuple: (list of documents, list of document IDs)
position += 1
else: Raises:
exist_document.pop(page.page_id) ValueError: When no website info list is found
# delete not selected documents """
if len(exist_document) > 0: documents: list[Document] = []
clean_notion_document_task.delay(list(exist_document.values()), dataset.id) document_ids: list[str] = []
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore current_position = position
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if not website_info: if not website_info:
raise ValueError("No website info list found.") raise ValueError("No website info list found.")
urls = website_info.urls urls = website_info.urls
for url in urls: for url in urls:
data_source_info = { data_source_info = {
"url": url, "url": url,
@ -1306,10 +1697,10 @@ class DocumentService:
"only_main_content": website_info.only_main_content, "only_main_content": website_info.only_main_content,
"mode": "crawl", "mode": "crawl",
} }
if len(url) > 255:
document_name = url[:200] + "..." # Truncate URL for document naming if too long
else: document_name = url[:200] + "..." if len(url) > 255 else url
document_name = url
document = DocumentService.build_document( document = DocumentService.build_document(
dataset, dataset,
dataset_process_rule.id, # type: ignore dataset_process_rule.id, # type: ignore
@ -1318,7 +1709,7 @@ class DocumentService:
knowledge_config.doc_language, knowledge_config.doc_language,
data_source_info, data_source_info,
created_from, created_from,
position, current_position,
account, account,
document_name, document_name,
batch, batch,
@ -1327,16 +1718,9 @@ class DocumentService:
db.session.flush() db.session.flush()
document_ids.append(document.id) document_ids.append(document.id)
documents.append(document) documents.append(document)
position += 1 current_position += 1
db.session.commit()
# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch return documents, document_ids
@staticmethod @staticmethod
def check_documents_upload_quota(count: int, features: FeatureModel): def check_documents_upload_quota(count: int, features: FeatureModel):

Loading…
Cancel
Save