From 00913a0843ea99bba3c757ddc66cd2c0ea2ea061 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Mon, 12 May 2025 19:59:27 +0800 Subject: [PATCH] model.query change to db.session.query --- api/commands.py | 7 +++--- .../console/explore/installed_app.py | 10 +++++---- .../index_tool_callback_handler.py | 10 ++++++--- api/core/indexing_runner.py | 6 ++--- api/core/rag/extractor/notion_extractor.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 22 ++++++++++++------- .../dataset_retriever_tool.py | 14 +++++++----- api/schedule/clean_messages.py | 4 +++- .../mail_clean_document_notify_task.py | 2 +- api/services/vector_service.py | 8 ++++--- 10 files changed, 53 insertions(+), 32 deletions(-) diff --git a/api/commands.py b/api/commands.py index c05ed786aa..66278a53a3 100644 --- a/api/commands.py +++ b/api/commands.py @@ -552,11 +552,12 @@ def old_metadata_migration(): page = 1 while True: try: - documents = ( - DatasetDocument.query.filter(DatasetDocument.doc_metadata is not None) + stmt = ( + select(DatasetDocument) + .filter(DatasetDocument.doc_metadata.is_not(None)) .order_by(DatasetDocument.created_at.desc()) - .paginate(page=page, per_page=50) ) + documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) except NotFound: break if not documents: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 9336c35a0d..4062972d08 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -66,7 +66,7 @@ class InstalledAppsListApi(Resource): parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: raise NotFound("App not found") @@ -79,9 +79,11 @@ class InstalledAppsListApi(Resource): if not app.is_public: raise Forbidden("You can't install a non-public app") - installed_app = InstalledApp.query.filter( - and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) - ).first() + installed_app = ( + db.session.query(InstalledApp) + .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .first() + ) if installed_app is None: # todo: position diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 7908bd0467..fd818d9a27 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -42,9 +42,13 @@ class DatasetIndexToolCallbackHandler: """Handle tool end.""" for document in documents: if document.metadata is not None: - dataset_document = DatasetDocument.query.filter( - DatasetDocument.id == document.metadata["document_id"] - ).first() + dataset_document = ( + db.session.query(DatasetDocument) + .filter(DatasetDocument.id == document.metadata["document_id"]) + .first() + ) + if not dataset_document: + continue if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunk = ( db.session.query(ChildChunk) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index c389496801..848d897779 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -660,10 +660,10 @@ class IndexingRunner: """ Update the document indexing status. """ - count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() + count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count() if count > 0: raise DocumentIsPausedError() - document = DatasetDocument.query.filter_by(id=document_id).first() + document = db.session.query(DatasetDocument).filter_by(id=document_id).first() if not document: raise DocumentIsDeletedPausedError() @@ -672,7 +672,7 @@ class IndexingRunner: if extra_update_params: update_params.update(extra_update_params) - DatasetDocument.query.filter_by(id=document_id).update(update_params) + db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) db.session.commit() @staticmethod diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 00a2150875..4e14800d0a 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor): data_source_info["last_edited_time"] = last_edited_time update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} - DocumentModel.query.filter_by(id=document_model.id).update(update_params) + db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params) db.session.commit() def get_notion_last_edited_time(self) -> str: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 444d7ee329..d3605da146 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -238,11 +238,15 @@ class DatasetRetrieval: for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = DatasetDocument.query.filter( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() + document = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .first() + ) if dataset and document: source = { "dataset_id": dataset.id, @@ -506,9 +510,11 @@ class DatasetRetrieval: dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: if document.metadata is not None: - dataset_document = DatasetDocument.query.filter( - DatasetDocument.id == document.metadata["document_id"] - ).first() + dataset_document = ( + db.session.query(DatasetDocument) + .filter(DatasetDocument.id == document.metadata["document_id"]) + .first() + ) if dataset_document: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunk = ( diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index c19c357d2a..fff261e0bd 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -186,11 +186,15 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = DatasetDocument.query.filter( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() + document = ( + db.session.query(DatasetDocument) # type: ignore + .filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .first() + ) if dataset and document: source = { "dataset_id": dataset.id, diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 5e4d3ec323..b213b154e7 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -46,7 +46,9 @@ def clean_messages(): break for message in messages: plan_sandbox_clean_message_day = message.created_at - app = App.query.filter_by(id=message.app_id).first() + app = db.session.query(App).filter_by(id=message.app_id).first() + if not app: + continue features_cache_key = f"features:{app.tenant_id}" plan_cache = redis_client.get(features_cache_key) if plan_cache is None: diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 0b3ff5d47d..5ee813e1de 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -54,7 +54,7 @@ def mail_clean_document_notify_task(): ) if not current_owner_join: continue - account = Account.query.filter(Account.id == current_owner_join.account_id).first() + account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first() if not account: continue diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 92422bf29d..696bcd2667 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -22,7 +22,9 @@ class VectorService: for segment in segments: if doc_form == IndexType.PARENT_CHILD_INDEX: - document = DatasetDocument.query.filter_by(id=segment.document_id).first() + document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() + if not document: + continue # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) @@ -52,7 +54,7 @@ class VectorService: raise ValueError("The knowledge base index technique is not high quality!") cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) else: - document = Document( + document = Document( # type: ignore page_content=segment.content, metadata={ "doc_id": segment.index_node_id, @@ -64,7 +66,7 @@ class VectorService: documents.append(document) if len(documents) > 0: index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) + index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) # type: ignore @classmethod def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):