From cef3a2d68fad57f39f80622b0726e7c067e1343a Mon Sep 17 00:00:00 2001 From: hjlarry Date: Mon, 12 May 2025 10:35:11 +0800 Subject: [PATCH] fix CI --- api/services/dataset_service.py | 34 ++++++++++------------ api/services/external_knowledge_service.py | 7 +++-- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 77ea794662..d4e540efe9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -9,7 +9,7 @@ from collections import Counter from typing import Any, Optional from flask_login import current_user -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -77,7 +77,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): - query = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) + query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: # get permitted dataset ids @@ -155,11 +155,10 @@ class DatasetService: @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = ( - db.session.query(Dataset) - .filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) - .paginate(page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) - ) + stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) + + datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) + return datasets.items, datasets.total @staticmethod @@ -512,12 +511,10 @@ class DatasetService: @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = ( - db.session.query(DatasetQuery) - .filter_by(dataset_id=dataset_id) - .order_by(db.desc(DatasetQuery.created_at)) - .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) - ) + stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at)) + + dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False) + return dataset_queries.items, dataset_queries.total @staticmethod @@ -2186,7 +2183,7 @@ class SegmentService: cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None ): query = ( - db.session.query(ChildChunk) + select(ChildChunk) .filter_by( tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, @@ -2197,7 +2194,7 @@ class SegmentService: ) if keyword: query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) - return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: @@ -2220,7 +2217,7 @@ class SegmentService: limit: int = 20, ): """Get segments for a document with optional filtering.""" - query = db.session.query(DocumentSegment).filter( + query = select(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id ) @@ -2230,9 +2227,8 @@ class SegmentService: if keyword: query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) - paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate( - page=page, per_page=limit, max_per_page=100, error_out=False - ) + query = query.order_by(DocumentSegment.position.asc()) + paginated_segments = db.paginate(query=query, page=page, per_page=limit, max_per_page=100, error_out=False) return paginated_segments.items, paginated_segments.total diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 5e671e0f4c..1a63286099 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast from urllib.parse import urlparse import httpx +from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy @@ -26,14 +27,16 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: query = ( - db.session.query(ExternalKnowledgeApis) + select(ExternalKnowledgeApis) .filter(ExternalKnowledgeApis.tenant_id == tenant_id) .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) - external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + external_knowledge_apis = db.paginate( + select=query, page=page, per_page=per_page, max_per_page=100, error_out=False + ) return external_knowledge_apis.items, external_knowledge_apis.total