pull/19509/head
hjlarry 1 year ago
parent eeeae7571a
commit cef3a2d68f

@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -77,7 +77,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
@ -155,11 +155,10 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
datasets = (
db.session.query(Dataset)
.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
.paginate(page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
)
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total
@staticmethod
@ -512,12 +511,10 @@ class DatasetService:
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
dataset_queries = (
db.session.query(DatasetQuery)
.filter_by(dataset_id=dataset_id)
.order_by(db.desc(DatasetQuery.created_at))
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
)
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
return dataset_queries.items, dataset_queries.total
@staticmethod
@ -2186,7 +2183,7 @@ class SegmentService:
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
query = (
db.session.query(ChildChunk)
select(ChildChunk)
.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
@ -2197,7 +2194,7 @@ class SegmentService:
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
@ -2220,7 +2217,7 @@ class SegmentService:
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = db.session.query(DocumentSegment).filter(
query = select(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)
@ -2230,9 +2227,8 @@ class SegmentService:
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
page=page, per_page=limit, max_per_page=100, error_out=False
)
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(query=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total

@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
import httpx
from sqlalchemy import select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@ -26,14 +27,16 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = (
db.session.query(ExternalKnowledgeApis)
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)
return external_knowledge_apis.items, external_knowledge_apis.total

Loading…
Cancel
Save