refactor: document segment query

pull/12372/head
Yeuoly 2 years ago
parent d36dece0af
commit 685e8cdc7d
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@@ -170,46 +170,47 @@ class DatasetDocumentListApi(Resource):
raise Forbidden(str(e)) raise Forbidden(str(e))
with Session(db.engine) as session: with Session(db.engine) as session:
query = session.execute( query = session.query(Document).filter_by(
select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id
).all() )
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.filter(Document.name.like(search)) query = query.filter(Document.name.like(search))
if sort.startswith("-"): if sort.startswith("-"):
sort_logic = desc sort_logic = desc
sort = sort[1:] sort = sort[1:]
else: else:
sort_logic = asc sort_logic = asc
if sort == "hit_count": if sort == "hit_count":
sub_query = ( sub_query = (
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) db.select(
.group_by(DocumentSegment.document_id) DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")
.subquery() )
) .group_by(DocumentSegment.document_id)
.subquery()
)
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position), sort_logic(Document.position),
) )
elif sort == "created_at": elif sort == "created_at":
query = query.order_by( query = query.order_by(
sort_logic(Document.created_at), sort_logic(Document.created_at),
sort_logic(Document.position), sort_logic(Document.position),
) )
else: else:
query = query.order_by( query = query.order_by(
desc(Document.created_at), desc(Document.created_at),
desc(Document.position), desc(Document.position),
) )
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items documents = paginated_documents.items
if fetch: if fetch:
with Session(db.engine) as session:
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
session.query(DocumentSegment) session.query(DocumentSegment)
@@ -228,17 +229,17 @@ class DatasetDocumentListApi(Resource):
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields) data = marshal(documents, document_with_segments_fields)
else: else:
data = marshal(documents, document_fields) data = marshal(documents, document_fields)
response = { response = {
"data": data, "data": data,
"has_more": len(documents) == limit, "has_more": len(documents) == limit,
"limit": limit, "limit": limit,
"total": paginated_documents.total, "total": paginated_documents.total,
"page": page, "page": page,
} }
return response return response
documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}

Loading…
Cancel
Save