replace all dataset.Model.query to db.session.query(Model)

pull/19509/head
hjlarry 1 year ago
parent 87da155477
commit bac9eefac5

@ -298,7 +298,8 @@ def migrate_knowledge_vector_database():
while True: while True:
try: try:
datasets = ( datasets = (
Dataset.query.filter(Dataset.indexing_technique == "high_quality") db.session.query(Dataset)
.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc()) .order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50) .paginate(page=page, per_page=50)
) )
@ -592,11 +593,15 @@ def old_metadata_migration():
) )
db.session.add(dataset_metadata_binding) db.session.add(dataset_metadata_binding)
else: else:
dataset_metadata_binding = DatasetMetadataBinding.query.filter( dataset_metadata_binding = (
DatasetMetadataBinding.dataset_id == document.dataset_id, db.session.query(DatasetMetadataBinding)
DatasetMetadataBinding.document_id == document.id, .filter(
DatasetMetadataBinding.metadata_id == dataset_metadata.id, DatasetMetadataBinding.dataset_id == document.dataset_id,
).first() DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.first()
)
if not dataset_metadata_binding: if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding( dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=document.tenant_id, tenant_id=document.tenant_id,

@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource):
) )
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
DocumentSegment.completed_at.isnot(None), db.session.query(DocumentSegment)
DocumentSegment.document_id == str(document.id), .filter(
DocumentSegment.status != "re_segment", DocumentSegment.completed_at.isnot(None),
).count() DocumentSegment.document_id == str(document.id),
total_segments = DocumentSegment.query.filter( DocumentSegment.status != "re_segment",
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() .count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields)) documents_status.append(marshal(document, document_status_fields))

@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource):
limits = DocumentService.DEFAULT_RULES["limits"] limits = DocumentService.DEFAULT_RULES["limits"]
if document_id: if document_id:
# get the latest process rule # get the latest process rule
document = Document.query.get_or_404(document_id) document = db.session.query(Document).get_or_404(document_id)
dataset = DatasetService.get_dataset(document.dataset_id) dataset = DatasetService.get_dataset(document.dataset_id)
@ -175,7 +175,9 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) query = db.session.query(Document).filter_by(
dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id
)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -213,14 +215,20 @@ class DatasetDocumentListApi(Resource):
documents = paginated_documents.items documents = paginated_documents.items
if fetch: if fetch:
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
DocumentSegment.completed_at.isnot(None), db.session.query(DocumentSegment)
DocumentSegment.document_id == str(document.id), .filter(
DocumentSegment.status != "re_segment", DocumentSegment.completed_at.isnot(None),
).count() DocumentSegment.document_id == str(document.id),
total_segments = DocumentSegment.query.filter( DocumentSegment.status != "re_segment",
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() .count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields) data = marshal(documents, document_with_segments_fields)
@ -563,14 +571,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch) documents = self.get_batch_documents(dataset_id, batch)
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
DocumentSegment.completed_at.isnot(None), db.session.query(DocumentSegment)
DocumentSegment.document_id == str(document.id), .filter(
DocumentSegment.status != "re_segment", DocumentSegment.completed_at.isnot(None),
).count() DocumentSegment.document_id == str(document.id),
total_segments = DocumentSegment.query.filter( DocumentSegment.status != "re_segment",
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() .count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:
@ -589,14 +603,20 @@ class DocumentIndexingStatusApi(DocumentResource):
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
completed_segments = DocumentSegment.query.filter( completed_segments = (
DocumentSegment.completed_at.isnot(None), db.session.query(DocumentSegment)
DocumentSegment.document_id == str(document_id), .filter(
DocumentSegment.status != "re_segment", DocumentSegment.completed_at.isnot(None),
).count() DocumentSegment.document_id == str(document_id),
total_segments = DocumentSegment.query.filter( DocumentSegment.status != "re_segment",
DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" )
).count() .count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments

@ -74,9 +74,14 @@ class DatasetDocumentSegmentListApi(Resource):
hit_count_gte = args["hit_count_gte"] hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"] keyword = args["keyword"]
query = DocumentSegment.query.filter( query = (
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).order_by(DocumentSegment.position.asc()) .filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.filter(DocumentSegment.status.in_(status_list))
@ -276,9 +281,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -320,9 +327,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -423,9 +432,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -478,9 +489,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -523,9 +536,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -567,16 +582,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# check child chunk # check child chunk
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter( child_chunk = (
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id db.session.query(ChildChunk)
).first() .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk: if not child_chunk:
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -612,16 +631,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# check child chunk # check child chunk
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter( child_chunk = (
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id db.session.query(ChildChunk)
).first() .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk: if not child_chunk:
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@ -5,7 +5,6 @@ from flask_restful import marshal, reqparse
from sqlalchemy import desc from sqlalchemy import desc
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services.dataset_service
from controllers.common.errors import FilenameNotExistsError from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
@ -337,7 +336,7 @@ class DocumentListApi(DatasetApiResource):
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) query = db.session.query(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -374,14 +373,20 @@ class DocumentIndexingStatusApi(DatasetApiResource):
raise NotFound("Documents not found.") raise NotFound("Documents not found.")
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
DocumentSegment.completed_at.isnot(None), db.session.query(DocumentSegment)
DocumentSegment.document_id == str(document.id), .filter(
DocumentSegment.status != "re_segment", DocumentSegment.completed_at.isnot(None),
).count() DocumentSegment.document_id == str(document.id),
total_segments = DocumentSegment.query.filter( DocumentSegment.status != "re_segment",
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() .count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:

@ -46,14 +46,22 @@ class DatasetIndexToolCallbackHandler:
DatasetDocument.id == document.metadata["document_id"] DatasetDocument.id == document.metadata["document_id"]
).first() ).first()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter( child_chunk = (
ChildChunk.index_node_id == document.metadata["doc_id"], db.session.query(ChildChunk)
ChildChunk.dataset_id == dataset_document.dataset_id, .filter(
ChildChunk.document_id == dataset_document.id, ChildChunk.index_node_id == document.metadata["doc_id"],
).first() ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk: if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( segment = (
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
) )
else: else:
query = db.session.query(DocumentSegment).filter( query = db.session.query(DocumentSegment).filter(

@ -51,7 +51,7 @@ class IndexingRunner:
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
try: try:
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
@ -103,15 +103,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is splitting.""" """Run the indexing process when the index_status is splitting."""
try: try:
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
# get exist document_segment list and delete # get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by( document_segments = (
dataset_id=dataset.id, document_id=dataset_document.id db.session.query(DocumentSegment)
).all() .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
for document_segment in document_segments: for document_segment in document_segments:
db.session.delete(document_segment) db.session.delete(document_segment)
@ -162,15 +164,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is indexing.""" """Run the indexing process when the index_status is indexing."""
try: try:
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
# get exist document_segment list and delete # get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by( document_segments = (
dataset_id=dataset.id, document_id=dataset_document.id db.session.query(DocumentSegment)
).all() .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
documents = [] documents = []
if document_segments: if document_segments:
@ -254,7 +258,7 @@ class IndexingRunner:
embedding_model_instance = None embedding_model_instance = None
if dataset_id: if dataset_id:
dataset = Dataset.query.filter_by(id=dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset not found.") raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
@ -587,7 +591,7 @@ class IndexingRunner:
@staticmethod @staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents): def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context(): with flask_app.app_context():
dataset = Dataset.query.filter_by(id=dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
keyword = Keyword(dataset) keyword = Keyword(dataset)
@ -676,7 +680,7 @@ class IndexingRunner:
""" """
Update the document segment by document id. Update the document segment by document id.
""" """
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit() db.session.commit()
def _transform( def _transform(

@ -237,7 +237,7 @@ class DatasetRetrieval:
if show_retrieve_source: if show_retrieve_source:
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter( document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -511,14 +511,23 @@ class DatasetRetrieval:
).first() ).first()
if dataset_document: if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter( child_chunk = (
ChildChunk.index_node_id == document.metadata["doc_id"], db.session.query(ChildChunk)
ChildChunk.dataset_id == dataset_document.dataset_id, .filter(
ChildChunk.document_id == dataset_document.id, ChildChunk.index_node_id == document.metadata["doc_id"],
).first() ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk: if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( segment = (
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
) )
db.session.commit() db.session.commit()
else: else:

@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = [] document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter( segments = (
DocumentSegment.dataset_id.in_(self.dataset_ids), db.session.query(DocumentSegment)
DocumentSegment.completed_at.isnot(None), .filter(
DocumentSegment.status == "completed", DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.enabled == True, DocumentSegment.completed_at.isnot(None),
DocumentSegment.index_node_id.in_(index_node_ids), DocumentSegment.status == "completed",
).all() DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
)
if segments: if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
context_list = [] context_list = []
resource_number = 1 resource_number = 1
for segment in sorted_segments: for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = Document.query.filter( document = (
Document.id == segment.document_id, db.session.query(Document)
Document.enabled == True, .filter(
Document.archived == False, Document.id == segment.document_id,
).first() Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document: if dataset and document:
source = { source = {
"position": resource_number, "position": resource_number,

@ -185,7 +185,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource: if self.return_resource:
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter( document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,

@ -275,12 +275,16 @@ class KnowledgeRetrievalNode(LLMNode):
if records: if records:
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = Document.query.filter( document = (
Document.id == segment.document_id, db.session.query(Document)
Document.enabled == True, .filter(
Document.archived == False, Document.id == segment.document_id,
).first() Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document: if dataset and document:
source = { source = {
"metadata": { "metadata": {

@ -93,7 +93,8 @@ class Dataset(Base):
@property @property
def latest_process_rule(self): def latest_process_rule(self):
return ( return (
DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.first() .first()
) )
@ -138,7 +139,8 @@ class Dataset(Base):
@property @property
def word_count(self): def word_count(self):
return ( return (
Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count)))
.filter(Document.dataset_id == self.id) .filter(Document.dataset_id == self.id)
.scalar() .scalar()
) )
@ -440,12 +442,13 @@ class Document(Base):
@property @property
def segment_count(self): def segment_count(self):
return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count() return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
@property @property
def hit_count(self): def hit_count(self):
return ( return (
DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
.filter(DocumentSegment.document_id == self.id) .filter(DocumentSegment.document_id == self.id)
.scalar() .scalar()
) )
@ -892,7 +895,7 @@ class DatasetKeywordTable(Base):
return dct return dct
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=self.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
if not dataset: if not dataset:
return None return None
if self.data_source_type == "database": if self.data_source_type == "database":

@ -52,7 +52,8 @@ def clean_unused_datasets_task():
# Main query with join and filter # Main query with join and filter
datasets = ( datasets = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) db.session.query(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter( .filter(
Dataset.created_at < plan_sandbox_clean_day, Dataset.created_at < plan_sandbox_clean_day,
@ -99,7 +100,7 @@ def clean_unused_datasets_task():
# update document # update document
update_params = {Document.enabled: False} update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit() db.session.commit()
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
except Exception as e: except Exception as e:
@ -136,7 +137,8 @@ def clean_unused_datasets_task():
# Main query with join and filter # Main query with join and filter
datasets = ( datasets = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) db.session.query(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter( .filter(
Dataset.created_at < plan_pro_clean_day, Dataset.created_at < plan_pro_clean_day,
@ -175,7 +177,7 @@ def clean_unused_datasets_task():
# update document # update document
update_params = {Document.enabled: False} update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit() db.session.commit()
click.echo( click.echo(
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")

@ -19,7 +19,9 @@ def create_tidb_serverless_task():
while True: while True:
try: try:
# check the number of idle tidb serverless # check the number of idle tidb serverless
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
)
if idle_tidb_serverless_number >= tidb_serverless_number: if idle_tidb_serverless_number >= tidb_serverless_number:
break break
# create tidb serverless # create tidb serverless

@ -29,7 +29,9 @@ def mail_clean_document_notify_task():
# send document clean notify mail # send document clean notify mail
try: try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
)
# group by tenant_id # group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs: for dataset_auto_disable_log in dataset_auto_disable_logs:
@ -65,7 +67,7 @@ def mail_clean_document_notify_task():
) )
for dataset_id, document_ids in dataset_auto_dataset_map.items(): for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset: if dataset:
document_count = len(document_ids) document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

@ -14,9 +14,11 @@ def update_tidb_serverless_status_task():
start_at = time.perf_counter() start_at = time.perf_counter()
try: try:
# check the number of idle tidb serverless # check the number of idle tidb serverless
tidb_serverless_list = TidbAuthBinding.query.filter( tidb_serverless_list = (
TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" db.session.query(TidbAuthBinding)
).all() .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
if len(tidb_serverless_list) == 0: if len(tidb_serverless_list) == 0:
return return
# update tidb serverless status # update tidb serverless status

@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService: class DatasetService:
@staticmethod @staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) query = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user: if user:
# get permitted dataset ids # get permitted dataset ids
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR: if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@ -153,8 +155,10 @@ class DatasetService:
@staticmethod @staticmethod
def get_datasets_by_ids(ids, tenant_id): def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( datasets = (
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False db.session.query(Dataset)
.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
.paginate(page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
) )
return datasets.items, datasets.total return datasets.items, datasets.total
@ -174,7 +178,7 @@ class DatasetService:
retrieval_model: Optional[RetrievalModel] = None, retrieval_model: Optional[RetrievalModel] = None,
): ):
# check if dataset name already exists # check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None embedding_model = None
if indexing_technique == "high_quality": if indexing_technique == "high_quality":
@ -235,7 +239,7 @@ class DatasetService:
@staticmethod @staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]: def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset return dataset
@staticmethod @staticmethod
@ -436,7 +440,7 @@ class DatasetService:
# update Retrieval model # update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"] filtered_data["retrieval_model"] = data["retrieval_model"]
dataset.query.filter_by(id=dataset_id).update(filtered_data) db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
db.session.commit() db.session.commit()
if action: if action:
@ -460,7 +464,7 @@ class DatasetService:
@staticmethod @staticmethod
def dataset_use_check(dataset_id) -> bool: def dataset_use_check(dataset_id) -> bool:
count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count() count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0: if count > 0:
return True return True
return False return False
@ -475,7 +479,9 @@ class DatasetService:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == "partial_members": if dataset.permission == "partial_members":
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
)
if ( if (
not user_permission not user_permission
and dataset.tenant_id != user.current_tenant_id and dataset.tenant_id != user.current_tenant_id
@ -499,14 +505,16 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any( if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
): ):
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod @staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int): def get_dataset_queries(dataset_id: str, page: int, per_page: int):
dataset_queries = ( dataset_queries = (
DatasetQuery.query.filter_by(dataset_id=dataset_id) db.session.query(DatasetQuery)
.filter_by(dataset_id=dataset_id)
.order_by(db.desc(DatasetQuery.created_at)) .order_by(db.desc(DatasetQuery.created_at))
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
) )
@ -515,7 +523,8 @@ class DatasetService:
@staticmethod @staticmethod
def get_related_apps(dataset_id: str): def get_related_apps(dataset_id: str):
return ( return (
AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at)) .order_by(db.desc(AppDatasetJoin.created_at))
.all() .all()
) )
@ -530,10 +539,14 @@ class DatasetService:
} }
# get recent 30 days auto disable logs # get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30) start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( dataset_auto_disable_logs = (
DatasetAutoDisableLog.dataset_id == dataset_id, db.session.query(DatasetAutoDisableLog)
DatasetAutoDisableLog.created_at >= start_date, .filter(
).all() DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
if dataset_auto_disable_logs: if dataset_auto_disable_logs:
return { return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs], "document_ids": [log.document_id for log in dataset_auto_disable_logs],
@ -873,7 +886,9 @@ class DocumentService:
@staticmethod @staticmethod
def get_documents_position(dataset_id): def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
)
if document: if document:
return document.position + 1 return document.position + 1
else: else:
@ -1010,13 +1025,17 @@ class DocumentService:
} }
# check duplicate # check duplicate
if knowledge_config.duplicate: if knowledge_config.duplicate:
document = Document.query.filter_by( document = (
dataset_id=dataset.id, db.session.query(Document)
tenant_id=current_user.current_tenant_id, .filter_by(
data_source_type="upload_file", dataset_id=dataset.id,
enabled=True, tenant_id=current_user.current_tenant_id,
name=file_name, data_source_type="upload_file",
).first() enabled=True,
name=file_name,
)
.first()
)
if document: if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@ -1054,12 +1073,16 @@ class DocumentService:
raise ValueError("No notion info list found.") raise ValueError("No notion info list found.")
exist_page_ids = [] exist_page_ids = []
exist_document = {} exist_document = {}
documents = Document.query.filter_by( documents = (
dataset_id=dataset.id, db.session.query(Document)
tenant_id=current_user.current_tenant_id, .filter_by(
data_source_type="notion_import", dataset_id=dataset.id,
enabled=True, tenant_id=current_user.current_tenant_id,
).all() data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents: if documents:
for document in documents: for document in documents:
data_source_info = json.loads(document.data_source_info) data_source_info = json.loads(document.data_source_info)
@ -1206,12 +1229,16 @@ class DocumentService:
@staticmethod @staticmethod
def get_tenant_documents_count(): def get_tenant_documents_count():
documents_count = Document.query.filter( documents_count = (
Document.completed_at.isnot(None), db.session.query(Document)
Document.enabled == True, .filter(
Document.archived == False, Document.completed_at.isnot(None),
Document.tenant_id == current_user.current_tenant_id, Document.enabled == True,
).count() Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
.count()
)
return documents_count return documents_count
@staticmethod @staticmethod
@ -1328,7 +1355,7 @@ class DocumentService:
db.session.commit() db.session.commit()
# update document segment # update document segment
update_params = {DocumentSegment.status: "re_segment"} update_params = {DocumentSegment.status: "re_segment"}
DocumentSegment.query.filter_by(document_id=document.id).update(update_params) db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
db.session.commit() db.session.commit()
# trigger async task # trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id) document_indexing_update_task.delay(document.dataset_id, document.id)
@ -1918,7 +1945,8 @@ class SegmentService:
@classmethod @classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
index_node_ids = ( index_node_ids = (
DocumentSegment.query.with_entities(DocumentSegment.index_node_id) db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter( .filter(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
@ -2157,12 +2185,16 @@ class SegmentService:
def get_child_chunks( def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
): ):
query = ChildChunk.query.filter_by( query = (
tenant_id=current_user.current_tenant_id, db.session.query(ChildChunk)
dataset_id=dataset_id, .filter_by(
document_id=document_id, tenant_id=current_user.current_tenant_id,
segment_id=segment_id, dataset_id=dataset_id,
).order_by(ChildChunk.position.asc()) document_id=document_id,
segment_id=segment_id,
)
.order_by(ChildChunk.position.asc())
)
if keyword: if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
@ -2170,7 +2202,11 @@ class SegmentService:
@classmethod @classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
"""Get a child chunk by its ID.""" """Get a child chunk by its ID."""
result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first() result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None return result if isinstance(result, ChildChunk) else None
@classmethod @classmethod
@ -2184,7 +2220,7 @@ class SegmentService:
limit: int = 20, limit: int = 20,
): ):
"""Get segments for a document with optional filtering.""" """Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter( query = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
) )
@ -2236,9 +2272,11 @@ class SegmentService:
raise ValueError(ex.description) raise ValueError(ex.description)
# check segment # check segment
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -2251,9 +2289,11 @@ class SegmentService:
@classmethod @classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID.""" """Get a segment by its ID."""
result = DocumentSegment.query.filter( result = (
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None return result if isinstance(result, DocumentSegment) else None

@ -25,8 +25,10 @@ from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService: class ExternalDatasetService:
@staticmethod @staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( query = (
ExternalKnowledgeApis.created_at.desc() db.session.query(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
) )
if search: if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
@ -92,18 +94,18 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( external_knowledge_api: Optional[ExternalKnowledgeApis] = (
id=external_knowledge_api_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
return external_knowledge_api return external_knowledge_api
@staticmethod @staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( external_knowledge_api: Optional[ExternalKnowledgeApis] = (
id=external_knowledge_api_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
@ -120,9 +122,9 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=external_knowledge_api_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
@ -131,25 +133,29 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count() count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
)
if count > 0: if count > 0:
return True, count return True, count
return False, 0 return False, 0
@staticmethod @staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
dataset_id=dataset_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
).first() )
if not external_knowledge_binding: if not external_knowledge_binding:
raise ValueError("external knowledge binding not found") raise ValueError("external knowledge binding not found")
return external_knowledge_binding return external_knowledge_binding
@staticmethod @staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=external_knowledge_api_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings) settings = json.loads(external_knowledge_api.settings)
@ -212,11 +218,13 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists # check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first(): if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id db.session.query(ExternalKnowledgeApis)
).first() .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
)
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
@ -254,15 +262,17 @@ class ExternalDatasetService:
external_retrieval_parameters: dict, external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None, metadata_condition: Optional[MetadataCondition] = None,
) -> list: ) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( external_knowledge_binding = (
dataset_id=dataset_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
).first() )
if not external_knowledge_binding: if not external_knowledge_binding:
raise ValueError("external knowledge binding not found") raise ValueError("external knowledge binding not found")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=external_knowledge_binding.external_knowledge_api_id db.session.query(ExternalKnowledgeApis)
).first() .filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
)
if not external_knowledge_api: if not external_knowledge_api:
raise ValueError("external api template not found") raise ValueError("external api template not found")

@ -20,9 +20,11 @@ class MetadataService:
@staticmethod @staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists # check if metadata name already exists
if DatasetMetadata.query.filter_by( if (
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name db.session.query(DatasetMetadata)
).first(): .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.") raise ValueError("Metadata name already exists.")
for field in BuiltInField: for field in BuiltInField:
if field.value == metadata_args.name: if field.value == metadata_args.name:
@ -42,16 +44,18 @@ class MetadataService:
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
lock_key = f"dataset_metadata_lock_{dataset_id}" lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists # check if metadata name already exists
if DatasetMetadata.query.filter_by( if (
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name db.session.query(DatasetMetadata)
).first(): .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.") raise ValueError("Metadata name already exists.")
for field in BuiltInField: for field in BuiltInField:
if field.value == name: if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.") raise ValueError("Metadata name already exists in Built-in fields.")
try: try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None: if metadata is None:
raise ValueError("Metadata not found.") raise ValueError("Metadata not found.")
old_name = metadata.name old_name = metadata.name
@ -60,7 +64,9 @@ class MetadataService:
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update related documents # update related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings: if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings] document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids) documents = DocumentService.get_document_by_ids(document_ids)
@ -82,13 +88,15 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}" lock_key = f"dataset_metadata_lock_{dataset_id}"
try: try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None: if metadata is None:
raise ValueError("Metadata not found.") raise ValueError("Metadata not found.")
db.session.delete(metadata) db.session.delete(metadata)
# deal related documents # deal related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings: if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings] document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids) documents = DocumentService.get_document_by_ids(document_ids)
@ -193,7 +201,7 @@ class MetadataService:
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
# deal metadata binding # deal metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete() db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list: for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding( dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -230,9 +238,9 @@ class MetadataService:
"id": item.get("id"), "id": item.get("id"),
"name": item.get("name"), "name": item.get("name"),
"type": item.get("type"), "type": item.get("type"),
"count": DatasetMetadataBinding.query.filter_by( "count": db.session.query(DatasetMetadataBinding)
metadata_id=item.get("id"), dataset_id=dataset.id .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
).count(), .count(),
} }
for item in dataset.doc_metadata or [] for item in dataset.doc_metadata or []
if item.get("id") != "built-in" if item.get("id") != "built-in"

@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit() db.session.commit()
document = Document( document = Document(
page_content=segment.content, page_content=segment.content,
@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit() db.session.commit()
end_at = time.perf_counter() end_at = time.perf_counter()

@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
start_at = time.perf_counter() start_at = time.perf_counter()
try: try:
dataset = Dataset.query.filter_by(id=dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset: if not dataset:
raise Exception("Dataset not found") raise Exception("Dataset not found")

Loading…
Cancel
Save