|
|
|
@ -9,7 +9,7 @@ from collections import Counter
|
|
|
|
from typing import Any, Optional
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
|
|
from flask_login import current_user
|
|
|
|
from flask_login import current_user
|
|
|
|
from sqlalchemy import func
|
|
|
|
from sqlalchemy import func, select
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from werkzeug.exceptions import NotFound
|
|
|
|
from werkzeug.exceptions import NotFound
|
|
|
|
|
|
|
|
|
|
|
|
@ -77,7 +77,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
|
|
|
|
class DatasetService:
|
|
|
|
class DatasetService:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
|
|
|
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
|
|
|
query = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
|
|
|
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
|
|
|
|
|
|
|
|
|
|
|
if user:
|
|
|
|
if user:
|
|
|
|
# get permitted dataset ids
|
|
|
|
# get permitted dataset ids
|
|
|
|
@ -155,11 +155,10 @@ class DatasetService:
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_datasets_by_ids(ids, tenant_id):
|
|
|
|
def get_datasets_by_ids(ids, tenant_id):
|
|
|
|
datasets = (
|
|
|
|
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
|
|
|
|
db.session.query(Dataset)
|
|
|
|
|
|
|
|
.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
|
|
|
|
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
|
|
|
|
.paginate(page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return datasets.items, datasets.total
|
|
|
|
return datasets.items, datasets.total
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
@ -512,12 +511,10 @@ class DatasetService:
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
|
|
|
|
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
|
|
|
|
dataset_queries = (
|
|
|
|
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
|
|
|
|
db.session.query(DatasetQuery)
|
|
|
|
|
|
|
|
.filter_by(dataset_id=dataset_id)
|
|
|
|
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
|
|
|
.order_by(db.desc(DatasetQuery.created_at))
|
|
|
|
|
|
|
|
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return dataset_queries.items, dataset_queries.total
|
|
|
|
return dataset_queries.items, dataset_queries.total
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
@ -2186,7 +2183,7 @@ class SegmentService:
|
|
|
|
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
|
|
|
|
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
|
|
|
|
):
|
|
|
|
):
|
|
|
|
query = (
|
|
|
|
query = (
|
|
|
|
db.session.query(ChildChunk)
|
|
|
|
select(ChildChunk)
|
|
|
|
.filter_by(
|
|
|
|
.filter_by(
|
|
|
|
tenant_id=current_user.current_tenant_id,
|
|
|
|
tenant_id=current_user.current_tenant_id,
|
|
|
|
dataset_id=dataset_id,
|
|
|
|
dataset_id=dataset_id,
|
|
|
|
@ -2197,7 +2194,7 @@ class SegmentService:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if keyword:
|
|
|
|
if keyword:
|
|
|
|
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
|
|
|
|
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
|
|
|
|
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
|
|
|
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
|
|
|
|
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
|
|
|
|
@ -2220,7 +2217,7 @@ class SegmentService:
|
|
|
|
limit: int = 20,
|
|
|
|
limit: int = 20,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""Get segments for a document with optional filtering."""
|
|
|
|
"""Get segments for a document with optional filtering."""
|
|
|
|
query = db.session.query(DocumentSegment).filter(
|
|
|
|
query = select(DocumentSegment).filter(
|
|
|
|
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
|
|
|
|
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@ -2230,9 +2227,8 @@ class SegmentService:
|
|
|
|
if keyword:
|
|
|
|
if keyword:
|
|
|
|
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
|
|
|
|
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
|
|
|
|
|
|
|
|
|
|
|
|
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
|
|
|
|
query = query.order_by(DocumentSegment.position.asc())
|
|
|
|
page=page, per_page=limit, max_per_page=100, error_out=False
|
|
|
|
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return paginated_segments.items, paginated_segments.total
|
|
|
|
return paginated_segments.items, paginated_segments.total
|
|
|
|
|
|
|
|
|
|
|
|
|