Report indexing status per upload batch instead of per document.

Summary of the change set:
  * DocumentService generates one batch id (timestamp + 6 random digits) per
    save request and passes it to save_document(), so every document created
    by the same request shares the same batch value.
  * DocumentService.get_batch_documents() and
    DocumentResource.get_batch_documents() look up all documents of a batch
    (dataset existence and permission are checked in the controller helper).
  * DocumentIndexingStatusApi.get() now takes a batch and returns a list with
    one marshalled status entry per document in that batch.
  * NotionPageReader.load_data_as_documents() wraps every text returned by
    read_page_as_documents() in its own Document; IndexingRunner uses it.
  * The DataSourceBinding primary-key constraint is renamed to
    source_binding_pkey.

Reviewer notes:
  * NOTE(review): hunk indentation and blank context lines were reconstructed
    from a whitespace-collapsed copy of this patch; verify against the target
    files before applying.
  * NOTE(review): the route converters (uuid / string) were restored from the
    handler signatures; confirm they match the target file.
  * NOTE(review): added lines were edited during review (blank-line fixes,
    None default for page_ids), so the post-image hashes on the index lines
    are stale.

diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index b2d61992c3..55193da432 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 import random
 from datetime import datetime
+from typing import List
 
 from flask import request
 from flask_login import login_required, current_user
@@ -83,6 +84,23 @@ class DocumentResource(Resource):
 
         return document
 
+    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        documents = DocumentService.get_batch_documents(dataset_id, batch)
+
+        if not documents:
+            raise NotFound('Documents not found.')
+
+        return documents
+
 
 class GetProcessRuleApi(Resource):
     @setup_required
@@ -340,23 +358,25 @@ class DocumentIndexingStatusApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, dataset_id, document_id):
+    def get(self, dataset_id, batch):
         dataset_id = str(dataset_id)
-        document_id = str(document_id)
-        document = self.get_document(dataset_id, document_id)
-
-        completed_segments = DocumentSegment.query \
-            .filter(DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document_id)) \
-            .count()
-        total_segments = DocumentSegment.query \
-            .filter_by(document_id=str(document_id)) \
-            .count()
+        batch = str(batch)
+        documents = self.get_batch_documents(dataset_id, batch)
+        documents_status = []
+        for document in documents:
+            completed_segments = DocumentSegment.query \
+                .filter(DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id)) \
+                .count()
+            total_segments = DocumentSegment.query \
+                .filter_by(document_id=str(document.id)) \
+                .count()
 
-        document.completed_segments = completed_segments
-        document.total_segments = total_segments
+            document.completed_segments = completed_segments
+            document.total_segments = total_segments
+            documents_status.append(marshal(document, self.document_status_fields))
 
-        return marshal(document, self.document_status_fields)
+        return documents_status
 
 
 class DocumentDetailApi(DocumentResource):
@@ -676,7 +696,7 @@ api.add_resource(DatasetInitApi,
 api.add_resource(DocumentIndexingEstimateApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
 api.add_resource(DocumentIndexingStatusApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status')
+                 '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status')
 api.add_resource(DocumentDetailApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
 api.add_resource(DocumentProcessingApi,
diff --git a/api/core/data_source/notion.py b/api/core/data_source/notion.py
index 7bb693c308..efbaf9137f 100644
--- a/api/core/data_source/notion.py
+++ b/api/core/data_source/notion.py
@@ -141,7 +141,7 @@ class NotionPageReader(BaseReader):
 
     def read_page_as_documents(self, page_id: str) -> List[str]:
         """Read a page as documents."""
-        return self._read_block(page_id)
+        return self._read_parent_blocks(page_id)
 
     def query_database(
         self, database_id: str, query_dict: Dict[str, Any] = {}
@@ -212,6 +212,26 @@ class NotionPageReader(BaseReader):
 
         return docs
 
+    def load_data_as_documents(
+        self, page_ids: Optional[List[str]] = None, database_id: Optional[str] = None
+    ) -> List[Document]:
+        if not page_ids and not database_id:
+            raise ValueError("Must specify either `page_ids` or `database_id`.")
+        docs = []
+        if database_id is not None:
+            # get all the pages in the database
+            page_ids = self.query_database(database_id)
+            for page_id in page_ids:
+                page_text = self.read_page(page_id)
+                docs.append(Document(page_text, extra_info={"page_id": page_id}))
+        else:
+            for page_id in page_ids:
+                page_text_list = self.read_page_as_documents(page_id)
+                for page_text in page_text_list:
+                    docs.append(Document(page_text, extra_info={"page_id": page_id}))
+
+        return docs
+
 
 if __name__ == "__main__":
     reader = NotionPageReader()
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index db200b2ab8..f9dd7e99ff 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -332,7 +332,7 @@ class IndexingRunner:
                     raise ValueError('Data source binding not found.')
                 page_ids = [page_id]
                 reader = NotionPageReader(integration_token=data_source_binding.access_token)
-                text_docs = reader.load_data(page_ids=page_ids)
+                text_docs = reader.load_data_as_documents(page_ids=page_ids)
                 return text_docs
 
     def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
diff --git a/api/models/source.py b/api/models/source.py
index 33eb38eea3..870797ccca 100644
--- a/api/models/source.py
+++ b/api/models/source.py
@@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import JSONB
 class DataSourceBinding(db.Model):
     __tablename__ = 'data_source_bindings'
     __table_args__ = (
-        db.PrimaryKeyConstraint('id', name='app_pkey'),
+        db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
         db.Index('app_tenant_id_idx', 'tenant_id')
     )
 
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 793070dfff..bcf72ed12b 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -3,7 +3,7 @@ import logging
 import datetime
 import time
 import random
-from typing import Optional
+from typing import Optional, List
 from extensions.ext_redis import redis_client
 from flask_login import current_user
 
@@ -278,6 +278,16 @@ class DocumentService:
         return document
 
     @staticmethod
+    def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
+        documents = db.session.query(Document).filter(
+            Document.batch == batch,
+            Document.dataset_id == dataset_id,
+            Document.tenant_id == current_user.current_tenant_id
+        ).all()
+
+        return documents
+
+    @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \
             filter(UploadFile.id == file_id). \
@@ -376,6 +386,7 @@ class DocumentService:
         db.session.add(dataset_process_rule)
         db.session.commit()
         position = DocumentService.get_documents_position(dataset.id)
+        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         document_ids = []
         documents = []
         if document_data["data_source"]["type"] == "upload_file":
@@ -398,7 +409,7 @@ class DocumentService:
                 document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                          document_data["data_source"]["type"],
                                                          data_source_info, created_from, position,
-                                                         account, file_name)
+                                                         account, file_name, batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)
@@ -426,7 +437,7 @@ class DocumentService:
                         document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                                  document_data["data_source"]["type"],
                                                                  data_source_info, created_from, position,
-                                                                 account, page['page_name'])
+                                                                 account, page['page_name'], batch)
                         db.session.add(document)
                         db.session.flush()
                         document_ids.append(document.id)
@@ -442,7 +453,7 @@ class DocumentService:
 
     @staticmethod
     def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
-                      created_from: str, position: int, account: Account, name: str):
+                      created_from: str, position: int, account: Account, name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -450,7 +461,7 @@ class DocumentService:
             data_source_type=data_source_type,
             data_source_info=json.dumps(data_source_info),
             dataset_process_rule_id=process_rule_id,
-            batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
+            batch=batch,
             name=name,
             created_from=created_from,
             created_by=account.id,