|
|
|
|
@ -252,7 +252,7 @@ class IndexingRunner:
|
|
|
|
|
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"total_segments": len(total_segments),
|
|
|
|
|
"total_segments": total_segments,
|
|
|
|
|
"tokens": tokens,
|
|
|
|
|
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
|
|
|
|
|
"currency": TokenCalculator.get_currency(self.embedding_model_name),
|
|
|
|
|
@ -261,25 +261,30 @@ class IndexingRunner:
|
|
|
|
|
|
|
|
|
|
def _load_data(self, document: Document) -> List[Document]:
|
|
|
|
|
# load file
|
|
|
|
|
if document.data_source_type != "upload_file":
|
|
|
|
|
if document.data_source_type not in ["upload_file", "notion_import"]:
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
data_source_info = document.data_source_info_dict
|
|
|
|
|
if not data_source_info or 'upload_file_id' not in data_source_info:
|
|
|
|
|
raise ValueError("no upload file found")
|
|
|
|
|
|
|
|
|
|
file_detail = db.session.query(UploadFile). \
|
|
|
|
|
filter(UploadFile.id == data_source_info['upload_file_id']). \
|
|
|
|
|
one_or_none()
|
|
|
|
|
|
|
|
|
|
text_docs = self._load_data_from_file(file_detail)
|
|
|
|
|
|
|
|
|
|
text_docs = []
|
|
|
|
|
if document.data_source_type == 'upload_file':
|
|
|
|
|
if not data_source_info or 'upload_file_id' not in data_source_info:
|
|
|
|
|
raise ValueError("no upload file found")
|
|
|
|
|
|
|
|
|
|
file_detail = db.session.query(UploadFile). \
|
|
|
|
|
filter(UploadFile.id == data_source_info['upload_file_id']). \
|
|
|
|
|
one_or_none()
|
|
|
|
|
|
|
|
|
|
text_docs = self._load_data_from_file(file_detail)
|
|
|
|
|
elif document.data_source_type == 'notion_import':
|
|
|
|
|
if not data_source_info or 'notion_page_id' not in data_source_info \
|
|
|
|
|
or 'notion_workspace_id' not in data_source_info:
|
|
|
|
|
raise ValueError("no notion page found")
|
|
|
|
|
text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'])
|
|
|
|
|
# update document status to splitting
|
|
|
|
|
self._update_document_index_status(
|
|
|
|
|
document_id=document.id,
|
|
|
|
|
after_indexing_status="splitting",
|
|
|
|
|
extra_update_params={
|
|
|
|
|
Document.file_id: file_detail.id,
|
|
|
|
|
Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
|
|
|
|
|
Document.parsing_completed_at: datetime.datetime.utcnow()
|
|
|
|
|
}
|
|
|
|
|
@ -314,6 +319,22 @@ class IndexingRunner:
|
|
|
|
|
|
|
|
|
|
return text_docs
|
|
|
|
|
|
|
|
|
|
def _load_data_from_notion(self, workspace_id: str, page_id: str) -> List[Document]:
|
|
|
|
|
data_source_binding = DataSourceBinding.query.filter(
|
|
|
|
|
db.and_(
|
|
|
|
|
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
|
|
|
|
DataSourceBinding.provider == 'notion',
|
|
|
|
|
DataSourceBinding.disabled == False,
|
|
|
|
|
DataSourceBinding.source_info['workspace_id'] == workspace_id
|
|
|
|
|
)
|
|
|
|
|
).first()
|
|
|
|
|
if not data_source_binding:
|
|
|
|
|
raise ValueError('Data source binding not found.')
|
|
|
|
|
page_ids = [page_id]
|
|
|
|
|
reader = NotionPageReader(integration_token=data_source_binding.access_token)
|
|
|
|
|
text_docs = reader.load_data(page_ids=page_ids)
|
|
|
|
|
return text_docs
|
|
|
|
|
|
|
|
|
|
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
|
|
|
|
|
"""
|
|
|
|
|
Get the NodeParser object according to the processing rule.
|
|
|
|
|
|