diff --git a/api/commands.py b/api/commands.py index 3a9589da4d..1792a8a67b 100644 --- a/api/commands.py +++ b/api/commands.py @@ -7,6 +7,7 @@ from typing import Optional import click from flask import current_app from werkzeug.exceptions import NotFound +from sqlalchemy import select from configs import dify_config from constants.languages import languages @@ -297,12 +298,11 @@ def migrate_knowledge_vector_database(): page = 1 while True: try: - datasets = ( - db.session.query(Dataset) - .filter(Dataset.indexing_technique == "high_quality") - .order_by(Dataset.created_at.desc()) - .paginate(page=page, per_page=50) + stmt = ( + select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) ) + + datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) except NotFound: break diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 54d2812886..ca18c25e74 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -6,7 +6,7 @@ from typing import cast from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal, marshal_with, reqparse -from sqlalchemy import asc, desc +from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource): limits = DocumentService.DEFAULT_RULES["limits"] if document_id: # get the latest process rule - document = db.session.query(Document).get_or_404(document_id) + document = db.get_or_404(Document, document_id) dataset = DatasetService.get_dataset(document.dataset_id) @@ -175,9 +175,7 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = db.session.query(Document).filter_by( - dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id - ) + query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: search = f"%{search}%" @@ -211,7 +209,7 @@ class DatasetDocumentListApi(Resource): desc(Document.position), ) - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2ed75cc346..8913765de4 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -5,6 +5,7 @@ from flask import request from flask_login import current_user from flask_restful import Resource, marshal, reqparse from werkzeug.exceptions import Forbidden, NotFound +from sqlalchemy import select import services from controllers.console import api @@ -76,7 +77,7 @@ class DatasetDocumentSegmentListApi(Resource): keyword = args["keyword"] query = ( - db.session.query(DocumentSegment) + select(DocumentSegment) .filter( DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id, @@ -99,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource): elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) response = { "data": marshal(segments.items, segment_fields), diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index be8289914e..46d4f1a5ce 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -2,7 +2,7 @@ import json from flask import request from flask_restful import marshal, reqparse -from sqlalchemy import desc +from sqlalchemy import desc, select from werkzeug.exceptions import NotFound from controllers.common.errors import FilenameNotExistsError @@ -24,6 +24,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +import services from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -336,7 +337,7 @@ class DocumentListApi(DatasetApiResource): if not dataset: raise NotFound("Dataset not found.") - query = db.session.query(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) + query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: search = f"%{search}%" @@ -344,7 +345,7 @@ class DocumentListApi(DatasetApiResource): query = query.order_by(desc(Document.created_at), desc(Document.position)) - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 0aed0f1e2b..c0cd42a226 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -2,7 +2,7 @@ import datetime import time import click -from sqlalchemy import func +from sqlalchemy import func, select from werkzeug.exceptions import NotFound import app @@ -51,8 +51,8 @@ def clean_unused_datasets_task(): ) # Main query with join and filter - datasets = ( - db.session.query(Dataset) + stmt = ( + select(Dataset) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( @@ -61,9 +61,10 @@ def clean_unused_datasets_task(): func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) .order_by(Dataset.created_at.desc()) - .paginate(page=1, per_page=50) ) + datasets = db.paginate(stmt, page=1, per_page=50) + except NotFound: break if datasets.items is None or len(datasets.items) == 0: @@ -136,8 +137,8 @@ def clean_unused_datasets_task(): ) # Main query with join and filter - datasets = ( - db.session.query(Dataset) + stmt = ( + select(Dataset) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( @@ -146,8 +147,8 @@ def clean_unused_datasets_task(): func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) .order_by(Dataset.created_at.desc()) - .paginate(page=1, per_page=50) ) + datasets = db.paginate(stmt, page=1, per_page=50) except NotFound: break diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index d4e540efe9..f9fe39f977 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -131,7 +131,7 @@ class DatasetService: else: return [], 0 - datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @@ -2228,7 +2228,7 @@ class SegmentService: query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.order_by(DocumentSegment.position.asc()) - paginated_segments = db.paginate(query=query, page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) return paginated_segments.items, paginated_segments.total