.filter( -> .where(

pull/22801/head
Asuka Minato 10 months ago committed by -LAN-
parent 451e593f37
commit f4bd3011d2
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -172,7 +172,7 @@ def migrate_annotation_vector_database():
per_page = 50 per_page = 50
apps = ( apps = (
db.session.query(App) db.session.query(App)
.filter(App.status == "normal") .where(App.status == "normal")
.order_by(App.created_at.desc()) .order_by(App.created_at.desc())
.limit(per_page) .limit(per_page)
.offset((page - 1) * per_page) .offset((page - 1) * per_page)
@ -202,7 +202,7 @@ def migrate_annotation_vector_database():
# get dataset_collection_binding info # get dataset_collection_binding info
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first() .first()
) )
if not dataset_collection_binding: if not dataset_collection_binding:
@ -332,7 +332,7 @@ def migrate_knowledge_vector_database():
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id) .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none() .one_or_none()
) )
if dataset_collection_binding: if dataset_collection_binding:
@ -367,7 +367,7 @@ def migrate_knowledge_vector_database():
dataset_documents = ( dataset_documents = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter( .where(
DatasetDocument.dataset_id == dataset.id, DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -381,7 +381,7 @@ def migrate_knowledge_vector_database():
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
@ -560,7 +560,7 @@ def old_metadata_migration():
try: try:
stmt = ( stmt = (
select(DatasetDocument) select(DatasetDocument)
.filter(DatasetDocument.doc_metadata.is_not(None)) .where(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc()) .order_by(DatasetDocument.created_at.desc())
) )
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -578,7 +578,7 @@ def old_metadata_migration():
else: else:
dataset_metadata = ( dataset_metadata = (
db.session.query(DatasetMetadata) db.session.query(DatasetMetadata)
.filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first() .first()
) )
if not dataset_metadata: if not dataset_metadata:
@ -602,7 +602,7 @@ def old_metadata_migration():
else: else:
dataset_metadata_binding = ( dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore db.session.query(DatasetMetadataBinding) # type: ignore
.filter( .where(
DatasetMetadataBinding.dataset_id == document.dataset_id, DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id, DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id, DatasetMetadataBinding.metadata_id == dataset_metadata.id,

@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource):
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = ( keys = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all() .all()
) )
return {"items": keys} return {"items": keys}
@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource):
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count() .count()
) )
@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter( .where(
getattr(ApiToken, self.resource_id_field) == resource_id, getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,

@ -121,7 +121,7 @@ class CompletionConversationDetailApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
@ -181,7 +181,7 @@ class ChatConversationApi(Resource):
Message.conversation_id == Conversation.id, Message.conversation_id == Conversation.id,
) )
.join(subquery, subquery.c.conversation_id == Conversation.id) .join(subquery, subquery.c.conversation_id == Conversation.id)
.filter( .where(
or_( or_(
Message.query.ilike(keyword_filter), Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter), Message.answer.ilike(keyword_filter),
@ -286,7 +286,7 @@ class ChatConversationDetailApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
@ -308,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversati
def _get_conversation(app_model, conversation_id): def _get_conversation(app_model, conversation_id):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )

@ -104,8 +104,8 @@ class AppMCPServerRefreshController(Resource):
raise NotFound() raise NotFound()
server = ( server = (
db.session.query(AppMCPServer) db.session.query(AppMCPServer)
.filter(AppMCPServer.id == server_id) .where(AppMCPServer.id == server_id)
.filter(AppMCPServer.tenant_id == current_user.current_tenant_id) .where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not server: if not server:

@ -56,7 +56,7 @@ class ChatMessageListApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first() .first()
) )
@ -66,7 +66,7 @@ class ChatMessageListApi(Resource):
if args["first_id"]: if args["first_id"]:
first_message = ( first_message = (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) .where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first() .first()
) )
@ -75,7 +75,7 @@ class ChatMessageListApi(Resource):
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at, Message.created_at < first_message.created_at,
Message.id != first_message.id, Message.id != first_message.id,
@ -87,7 +87,7 @@ class ChatMessageListApi(Resource):
else: else:
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == conversation.id) .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args["limit"]) .limit(args["limit"])
.all() .all()
@ -98,7 +98,7 @@ class ChatMessageListApi(Resource):
current_page_first_message = history_messages[-1] current_page_first_message = history_messages[-1]
rest_count = ( rest_count = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at, Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id, Message.id != current_page_first_message.id,

@ -11,7 +11,7 @@ from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]: def _load_app_model(app_id: str) -> Optional[App]:
app_model = ( app_model = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
return app_model return app_model

@ -30,7 +30,7 @@ class DataSourceApi(Resource):
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = ( data_source_integrates = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )

@ -412,7 +412,7 @@ class DatasetIndexingEstimateApi(Resource):
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = ( file_details = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all() .all()
) )
@ -517,14 +517,14 @@ class DatasetIndexingStatusApi(Resource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all() .all()
) )
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -533,7 +533,7 @@ class DatasetIndexingStatusApi(Resource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
@ -568,7 +568,7 @@ class DatasetApiKeyApi(Resource):
def get(self): def get(self):
keys = ( keys = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all() .all()
) )
return {"items": keys} return {"items": keys}
@ -584,7 +584,7 @@ class DatasetApiKeyApi(Resource):
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count() .count()
) )
@ -620,7 +620,7 @@ class DatasetApiDeleteApi(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter( .where(
ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,

@ -124,7 +124,7 @@ class GetProcessRuleApi(Resource):
# get the latest process rule # get the latest process rule
dataset_process_rule = ( dataset_process_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == document.dataset_id) .where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.limit(1) .limit(1)
.one_or_none() .one_or_none()
@ -176,7 +176,7 @@ class DatasetDocumentListApi(Resource):
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.filter(Document.name.like(search)) query = query.where(Document.name.like(search))
if sort.startswith("-"): if sort.startswith("-"):
sort_logic = desc sort_logic = desc
@ -212,7 +212,7 @@ class DatasetDocumentListApi(Resource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -221,7 +221,7 @@ class DatasetDocumentListApi(Resource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
document.completed_segments = completed_segments document.completed_segments = completed_segments
@ -417,7 +417,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
file = ( file = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -492,7 +492,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -568,7 +568,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -577,7 +577,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
@ -611,7 +611,7 @@ class DocumentIndexingStatusApi(DocumentResource):
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -620,7 +620,7 @@ class DocumentIndexingStatusApi(DocumentResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count() .count()
) )

@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = ( query = (
select(DocumentSegment) select(DocumentSegment)
.filter( .where(
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id, DocumentSegment.tenant_id == current_user.current_tenant_id,
) )
@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource):
) )
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.where(DocumentSegment.status.in_(status_list))
if hit_count_gte is not None: if hit_count_gte is not None:
query = query.filter(DocumentSegment.hit_count >= hit_count_gte) query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword: if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all": if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true": if args["enabled"].lower() == "true":
query = query.filter(DocumentSegment.enabled == True) query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false": elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False) query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not child_chunk: if not child_chunk:
@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not child_chunk: if not child_chunk:

@ -34,7 +34,7 @@ class InstalledAppsListApi(Resource):
if app_id: if app_id:
installed_apps = ( installed_apps = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all() .all()
) )
else: else:
@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource):
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first() .first()
) )

@ -28,7 +28,7 @@ def installed_app_required(view=None):
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.filter( .where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
) )
.first() .first()

@ -21,7 +21,7 @@ def plugin_permission_required(
with Session(db.engine) as session: with Session(db.engine) as session:
permission = ( permission = (
session.query(TenantPluginPermission) session.query(TenantPluginPermission)
.filter( .where(
TenantPluginPermission.tenant_id == tenant_id, TenantPluginPermission.tenant_id == tenant_id,
) )
.first() .first()

@ -68,7 +68,7 @@ class AccountInitApi(Resource):
# check invitation code # check invitation code
invitation_code = ( invitation_code = (
db.session.query(InvitationCode) db.session.query(InvitationCode)
.filter( .where(
InvitationCode.code == args["invitation_code"], InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused", InvitationCode.status == "unused",
) )

@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None):
try: try:
tenant_model = ( tenant_model = (
db.session.query(Tenant) db.session.query(Tenant)
.filter( .where(
Tenant.id == tenant_id, Tenant.id == tenant_id,
) )
.first() .first()

@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource):
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.filter(Document.name.like(search)) query = query.where(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position)) query = query.order_by(desc(Document.created_at), desc(Document.position))
@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields

@ -62,10 +62,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id) .where(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"])) .where(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL) .where(Tenant.status == TenantStatus.NORMAL)
.one_or_none() .one_or_none()
) # TODO: only owner information is required, so only one is returned. ) # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
@ -213,10 +213,10 @@ def validate_dataset_token(view=None):
api_token = validate_and_get_api_token("dataset") api_token = validate_and_get_api_token("dataset")
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id) .where(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"])) .where(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL) .where(Tenant.status == TenantStatus.NORMAL)
.one_or_none() .one_or_none()
) # TODO: only owner information is required, so only one is returned. ) # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
@ -293,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
end_user = ( end_user = (
db.session.query(EndUser) db.session.query(EndUser)
.filter( .where(
EndUser.tenant_id == app_model.tenant_id, EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id, EndUser.app_id == app_model.id,
EndUser.session_id == user_id, EndUser.session_id == user_id,

@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner):
# get how many agent thoughts have been created # get how many agent thoughts have been created
self.agent_thought_count = ( self.agent_thought_count = (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought)
.filter( .where(
MessageAgentThought.message_id == self.message.id, MessageAgentThought.message_id == self.message.id,
) )
.count() .count()

@ -248,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
""" """
message = ( message = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.id == message_id, Message.id == message_id,
Message.app_id == app_model.id, Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"), Message.from_source == ("api" if isinstance(user, EndUser) else "console"),

@ -85,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if conversation: if conversation:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig) db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first() .first()
) )

@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter( .where(
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id, ChildChunk.document_id == dataset_document.id,
@ -69,7 +69,7 @@ class DatasetIndexToolCallbackHandler:
if child_chunk: if child_chunk:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id) .where(DocumentSegment.id == child_chunk.segment_id)
.update( .update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
) )
@ -80,7 +80,7 @@ class DatasetIndexToolCallbackHandler:
) )
if "dataset_id" in document.metadata: if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment # add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel):
provider_record = ( provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .where(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names), Provider.provider_name.in_(provider_names),
@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel):
provider_model_record = ( provider_model_record = (
db.session.query(ProviderModel) db.session.query(ProviderModel)
.filter( .where(
ProviderModel.tenant_id == self.tenant_id, ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names), ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model, ProviderModel.model_name == model,
@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel):
return ( return (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.filter( .where(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel):
return ( return (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel):
load_balancing_config_count = ( load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ( model_setting = (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.filter( .where(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = ( preferred_model_provider = (
db.session.query(TenantPreferredModelProvider) db.session.query(TenantPreferredModelProvider)
.filter( .where(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names), TenantPreferredModelProvider.provider_name.in_(provider_names),
) )

@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension # get api_based_extension
api_based_extension = ( api_based_extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )
@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension # get api_based_extension
api_based_extension = ( api_based_extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )

@ -59,7 +59,7 @@ class IndexingRunner:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -124,7 +124,7 @@ class IndexingRunner:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -212,7 +212,7 @@ class IndexingRunner:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )

@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler:
def retrieve_end_user(self): def retrieve_end_user(self):
return ( return (
db.session.query(EndUser) db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )

@ -89,7 +89,7 @@ class ApiModeration(Moderation):
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = ( extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )

@ -703,7 +703,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata, WorkflowNodeExecutionModel.execution_metadata,
) )
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all() .all()
) )
return workflow_nodes return workflow_nodes

@ -218,7 +218,7 @@ class OpsTraceManager:
""" """
trace_config_data: Optional[TraceAppConfig] = ( trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig) db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first() .first()
) )
@ -304,7 +304,7 @@ class OpsTraceManager:
if conversation_data.app_model_config_id: if conversation_data.app_model_config_id:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig) db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation_data.app_model_config_id) .where(AppModelConfig.id == conversation_data.app_model_config_id)
.first() .first()
) )
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:

@ -275,7 +275,7 @@ class ProviderManager:
# Get the corresponding TenantDefaultModel record # Get the corresponding TenantDefaultModel record
default_model = ( default_model = (
db.session.query(TenantDefaultModel) db.session.query(TenantDefaultModel)
.filter( .where(
TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(), TenantDefaultModel.model_type == model_type.to_origin_model_type(),
) )
@ -367,7 +367,7 @@ class ProviderManager:
# Get the list of available models from get_configurations and check if it is LLM # Get the list of available models from get_configurations and check if it is LLM
default_model = ( default_model = (
db.session.query(TenantDefaultModel) db.session.query(TenantDefaultModel)
.filter( .where(
TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(), TenantDefaultModel.model_type == model_type.to_origin_model_type(),
) )
@ -541,7 +541,7 @@ class ProviderManager:
db.session.rollback() db.session.rollback()
existed_provider_record = ( existed_provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .where(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value, Provider.provider_type == ProviderType.SYSTEM.value,

@ -97,7 +97,7 @@ class Jieba(BaseKeyword):
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
) )
if document_ids_filter: if document_ids_filter:
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter)) segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first() segment = segment_query.first()
if segment: if segment:
@ -214,7 +214,7 @@ class Jieba(BaseKeyword):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = ( document_segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first() .first()
) )
if document_segment: if document_segment:

@ -294,7 +294,7 @@ class RetrievalService:
dataset_documents = { dataset_documents = {
doc.id: doc doc.id: doc
for doc in db.session.query(DatasetDocument) for doc in db.session.query(DatasetDocument)
.filter(DatasetDocument.id.in_(document_ids)) .where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all() .all()
} }
@ -326,7 +326,7 @@ class RetrievalService:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
@ -381,7 +381,7 @@ class RetrievalService:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",

@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id) .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none() .one_or_none()
) )
if dataset_collection_binding: if dataset_collection_binding:

@ -424,7 +424,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = ( tidb_auth_binding = (
db.session.query(TidbAuthBinding) db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id) .where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none() .one_or_none()
) )
if tidb_auth_binding: if tidb_auth_binding:
@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else: else:
idle_tidb_auth_binding = ( idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding) db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1) .limit(1)
.one_or_none() .one_or_none()
) )

@ -47,7 +47,7 @@ class Vector:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE: if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = ( whitelist = (
db.session.query(Whitelist) db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none() .one_or_none()
) )
if whitelist: if whitelist:

@ -63,7 +63,7 @@ class DatasetDocumentStore:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id) .where(DocumentSegment.document_id == self._document_id)
.scalar() .scalar()
) )
@ -230,7 +230,7 @@ class DatasetDocumentStore:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = ( document_segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first() .first()
) )

@ -366,7 +366,7 @@ class NotionExtractor(BaseExtractor):
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .where(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == tenant_id, DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",

@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = ( child_node_ids = (
db.session.query(ChildChunk.index_node_id) db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter( .where(
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids), DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,

@ -242,7 +242,7 @@ class DatasetRetrieval:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = ( document = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter( .where(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
@ -516,14 +516,14 @@ class DatasetRetrieval:
if document.metadata is not None: if document.metadata is not None:
dataset_document = ( dataset_document = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter(DatasetDocument.id == document.metadata["document_id"]) .where(DatasetDocument.id == document.metadata["document_id"])
.first() .first()
) )
if dataset_document: if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter( .where(
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id, ChildChunk.document_id == dataset_document.id,
@ -533,7 +533,7 @@ class DatasetRetrieval:
if child_chunk: if child_chunk:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id) .where(DocumentSegment.id == child_chunk.segment_id)
.update( .update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False, synchronize_session=False,
@ -547,7 +547,7 @@ class DatasetRetrieval:
# if 'dataset_id' in document.metadata: # if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata: if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment # add hit count to document segment
query.update( query.update(
@ -930,9 +930,9 @@ class DatasetRetrieval:
raise ValueError("Invalid metadata filtering mode") raise ValueError("Invalid metadata filtering mode")
if filters: if filters:
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters)) document_query = document_query.where(and_(*filters))
else: else:
document_query = document_query.filter(or_(*filters)) document_query = document_query.where(or_(*filters))
documents = document_query.all() documents = document_query.all()
# group by dataset_id # group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore

@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers # get tenant api providers
db_providers: list[ApiToolProvider] = ( db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all() .all()
) )

@ -160,7 +160,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session: with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = ( tool_file: ToolFile | None = (
session.query(ToolFile) session.query(ToolFile)
.filter( .where(
ToolFile.id == id, ToolFile.id == id,
) )
.first() .first()
@ -184,7 +184,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session: with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = ( message_file: MessageFile | None = (
session.query(MessageFile) session.query(MessageFile)
.filter( .where(
MessageFile.id == id, MessageFile.id == id,
) )
.first() .first()
@ -204,7 +204,7 @@ class ToolFileManager:
tool_file: ToolFile | None = ( tool_file: ToolFile | None = (
session.query(ToolFile) session.query(ToolFile)
.filter( .where(
ToolFile.id == tool_file_id, ToolFile.id == tool_file_id,
) )
.first() .first()
@ -228,7 +228,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session: with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = ( tool_file: ToolFile | None = (
session.query(ToolFile) session.query(ToolFile)
.filter( .where(
ToolFile.id == tool_file_id, ToolFile.id == tool_file_id,
) )
.first() .first()

@ -57,7 +57,7 @@ class ToolLabelManager:
labels = ( labels = (
db.session.query(ToolLabelBinding.label_name) db.session.query(ToolLabelBinding.label_name)
.filter( .where(
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value, ToolLabelBinding.tool_type == controller.provider_type.value,
) )

@ -198,7 +198,7 @@ class ToolManager:
try: try:
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id, BuiltinToolProvider.id == credential_id,
) )
@ -216,7 +216,7 @@ class ToolManager:
# use the default provider # use the default provider
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity)) (BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name), | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
@ -229,7 +229,7 @@ class ToolManager:
else: else:
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first() .first()
) )
@ -316,7 +316,7 @@ class ToolManager:
elif provider_type == ToolProviderType.WORKFLOW: elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = ( workflow_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first() .first()
) )
@ -731,7 +731,7 @@ class ToolManager:
""" """
provider: ApiToolProvider | None = ( provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.id == provider_id, ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
) )
@ -768,7 +768,7 @@ class ToolManager:
""" """
provider: MCPToolProvider | None = ( provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter( .where(
MCPToolProvider.server_identifier == provider_id, MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.tenant_id == tenant_id,
) )
@ -793,7 +793,7 @@ class ToolManager:
provider_name = provider provider_name = provider
provider_obj: ApiToolProvider | None = ( provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider, ApiToolProvider.name == provider,
) )
@ -885,7 +885,7 @@ class ToolManager:
try: try:
workflow_provider: WorkflowToolProvider | None = ( workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first() .first()
) )
@ -902,7 +902,7 @@ class ToolManager:
try: try:
api_provider: ApiToolProvider | None = ( api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first() .first()
) )
@ -919,7 +919,7 @@ class ToolManager:
try: try:
mcp_provider: MCPToolProvider | None = ( mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first() .first()
) )

@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = ( document = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,

@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = ( document = (
db.session.query(DatasetDocument) # type: ignore db.session.query(DatasetDocument) # type: ignore
.filter( .where(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,

@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController):
""" """
workflow: Workflow | None = ( workflow: Workflow | None = (
db.session.query(Workflow) db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first() .first()
) )
@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
db_providers: WorkflowToolProvider | None = ( db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter( .where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id, WorkflowToolProvider.app_id == self.provider_id,
) )

@ -142,7 +142,7 @@ class WorkflowTool(Tool):
if not version: if not version:
workflow = ( workflow = (
db.session.query(Workflow) db.session.query(Workflow)
.filter(Workflow.app_id == app_id, Workflow.version != "draft") .where(Workflow.app_id == app_id, Workflow.version != "draft")
.order_by(Workflow.created_at.desc()) .order_by(Workflow.created_at.desc())
.first() .first()
) )

@ -228,7 +228,7 @@ class KnowledgeRetrievalNode(BaseNode):
# Subquery: Count the number of available documents for each dataset # Subquery: Count the number of available documents for each dataset
subquery = ( subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.filter( .where(
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -242,8 +242,8 @@ class KnowledgeRetrievalNode(BaseNode):
results = ( results = (
db.session.query(Dataset) db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all() .all()
) )
@ -370,7 +370,7 @@ class KnowledgeRetrievalNode(BaseNode):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = ( document = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -493,9 +493,9 @@ class KnowledgeRetrievalNode(BaseNode):
node_data.metadata_filtering_conditions node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and" and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore ): # type: ignore
document_query = document_query.filter(and_(*filters)) document_query = document_query.where(and_(*filters))
else: else:
document_query = document_query.filter(or_(*filters)) document_query = document_query.where(or_(*filters))
documents = document_query.all() documents = document_query.all()
# group by dataset_id # group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore

@ -22,7 +22,7 @@ def handle(sender, **kwargs):
document = ( document = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.id == document_id, Document.id == document_id,
Document.dataset_id == dataset_id, Document.dataset_id == dataset_id,
) )

@ -40,9 +40,9 @@ def load_user_from_request(request_from_flask_login):
if workspace_id: if workspace_id:
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == workspace_id) .where(Tenant.id == workspace_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role == "owner") .where(TenantAccountJoin.role == "owner")
.one_or_none() .one_or_none()
) )
if tenant_account_join: if tenant_account_join:
@ -83,7 +83,7 @@ def load_user_from_request(request_from_flask_login):
raise NotFound("App MCP server not found.") raise NotFound("App MCP server not found.")
end_user = ( end_user = (
db.session.query(EndUser) db.session.query(EndUser)
.filter(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )
if not end_user: if not end_user:

@ -137,9 +137,9 @@ class Account(UserMixin, Base):
tuple[Tenant, TenantAccountJoin], tuple[Tenant, TenantAccountJoin],
( (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == tenant_id) .where(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id) .where(TenantAccountJoin.account_id == self.id)
.one_or_none() .one_or_none()
), ),
) )
@ -163,7 +163,7 @@ class Account(UserMixin, Base):
def get_by_openid(cls, provider: str, open_id: str): def get_by_openid(cls, provider: str, open_id: str):
account_integrate = ( account_integrate = (
db.session.query(AccountIntegrate) db.session.query(AccountIntegrate)
.filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none() .one_or_none()
) )
if account_integrate: if account_integrate:
@ -213,7 +213,7 @@ class Tenant(Base):
def get_accounts(self) -> list[Account]: def get_accounts(self) -> list[Account]:
return ( return (
db.session.query(Account) db.session.query(Account)
.filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.all() .all()
) )

@ -95,7 +95,7 @@ class Dataset(Base):
def latest_process_rule(self): def latest_process_rule(self):
return ( return (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id) .where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.first() .first()
) )
@ -104,7 +104,7 @@ class Dataset(Base):
def app_count(self): def app_count(self):
return ( return (
db.session.query(func.count(AppDatasetJoin.id)) db.session.query(func.count(AppDatasetJoin.id))
.filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar() .scalar()
) )
@ -116,7 +116,7 @@ class Dataset(Base):
def available_document_count(self): def available_document_count(self):
return ( return (
db.session.query(func.count(Document.id)) db.session.query(func.count(Document.id))
.filter( .where(
Document.dataset_id == self.id, Document.dataset_id == self.id,
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
@ -129,7 +129,7 @@ class Dataset(Base):
def available_segment_count(self): def available_segment_count(self):
return ( return (
db.session.query(func.count(DocumentSegment.id)) db.session.query(func.count(DocumentSegment.id))
.filter( .where(
DocumentSegment.dataset_id == self.id, DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
@ -142,7 +142,7 @@ class Dataset(Base):
return ( return (
db.session.query(Document) db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0)) .with_entities(func.coalesce(func.sum(Document.word_count), 0))
.filter(Document.dataset_id == self.id) .where(Document.dataset_id == self.id)
.scalar() .scalar()
) )
@ -169,7 +169,7 @@ class Dataset(Base):
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id) .join(TagBinding, Tag.id == TagBinding.tag_id)
.filter( .where(
TagBinding.target_id == self.id, TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id, TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id,
@ -191,7 +191,7 @@ class Dataset(Base):
return None return None
external_knowledge_api = db.session.scalars( external_knowledge_api = db.session.scalars(
select(ExternalKnowledgeApis) select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) .where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.limit(1) .limit(1)
).first() ).first()
if not external_knowledge_api: if not external_knowledge_api:
@ -408,7 +408,7 @@ class Document(Base):
data_source_info_dict = json.loads(self.data_source_info) data_source_info_dict = json.loads(self.data_source_info)
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.id == data_source_info_dict["upload_file_id"]) .where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none() .one_or_none()
) )
if file_detail: if file_detail:
@ -452,7 +452,7 @@ class Document(Base):
return ( return (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.filter(DocumentSegment.document_id == self.id) .where(DocumentSegment.document_id == self.id)
.scalar() .scalar()
) )
@ -475,7 +475,7 @@ class Document(Base):
document_metadatas = ( document_metadatas = (
db.session.query(DatasetMetadata) db.session.query(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.filter( .where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
) )
.all() .all()
@ -697,7 +697,7 @@ class DocumentSegment(Base):
def previous_segment(self): def previous_segment(self):
return db.session.scalars( return db.session.scalars(
select(DocumentSegment) select(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) .where(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
.limit(1) .limit(1)
).first() ).first()
@ -705,7 +705,7 @@ class DocumentSegment(Base):
def next_segment(self): def next_segment(self):
return db.session.scalars( return db.session.scalars(
select(DocumentSegment) select(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) .where(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
.limit(1) .limit(1)
).first() ).first()
@ -717,7 +717,7 @@ class DocumentSegment(Base):
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = ( child_chunks = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id) .where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc()) .order_by(ChildChunk.position.asc())
.all() .all()
) )
@ -734,7 +734,7 @@ class DocumentSegment(Base):
if rules.parent_mode: if rules.parent_mode:
child_chunks = ( child_chunks = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id) .where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc()) .order_by(ChildChunk.position.asc())
.all() .all()
) )
@ -1044,7 +1044,7 @@ class ExternalKnowledgeApis(Base):
def dataset_bindings(self): def dataset_bindings(self):
external_knowledge_bindings = ( external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings) db.session.query(ExternalKnowledgeBindings)
.filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all() .all()
) )
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]

@ -282,7 +282,7 @@ class App(Base):
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id) .join(TagBinding, Tag.id == TagBinding.tag_id)
.filter( .where(
TagBinding.target_id == self.id, TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id, TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id,
@ -751,7 +751,7 @@ class Conversation(Base):
def user_feedback_stats(self): def user_feedback_stats(self):
like = ( like = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user", MessageFeedback.from_source == "user",
MessageFeedback.rating == "like", MessageFeedback.rating == "like",
@ -761,7 +761,7 @@ class Conversation(Base):
dislike = ( dislike = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user", MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike", MessageFeedback.rating == "dislike",
@ -775,7 +775,7 @@ class Conversation(Base):
def admin_feedback_stats(self): def admin_feedback_stats(self):
like = ( like = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin", MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like", MessageFeedback.rating == "like",
@ -785,7 +785,7 @@ class Conversation(Base):
dislike = ( dislike = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin", MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike", MessageFeedback.rating == "dislike",
@ -824,7 +824,7 @@ class Conversation(Base):
def first_message(self): def first_message(self):
return ( return (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == self.id) .where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc()) .order_by(Message.created_at.asc())
.first() .first()
) )
@ -1040,7 +1040,7 @@ class Message(Base):
def user_feedback(self): def user_feedback(self):
feedback = ( feedback = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first() .first()
) )
return feedback return feedback
@ -1049,7 +1049,7 @@ class Message(Base):
def admin_feedback(self): def admin_feedback(self):
feedback = ( feedback = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first() .first()
) )
return feedback return feedback
@ -1072,7 +1072,7 @@ class Message(Base):
if annotation_history: if annotation_history:
annotation = ( annotation = (
db.session.query(MessageAnnotation) db.session.query(MessageAnnotation)
.filter(MessageAnnotation.id == annotation_history.annotation_id) .where(MessageAnnotation.id == annotation_history.annotation_id)
.first() .first()
) )
return annotation return annotation
@ -1082,9 +1082,7 @@ class Message(Base):
def app_model_config(self): def app_model_config(self):
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first() conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
if conversation: if conversation:
return ( return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
)
return None return None
@ -1100,7 +1098,7 @@ class Message(Base):
def agent_thoughts(self): def agent_thoughts(self):
return ( return (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == self.id) .where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc()) .order_by(MessageAgentThought.position.asc())
.all() .all()
) )
@ -1371,7 +1369,7 @@ class AppAnnotationHitHistory(Base):
account = ( account = (
db.session.query(Account) db.session.query(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id) .join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.filter(MessageAnnotation.id == self.annotation_id) .where(MessageAnnotation.id == self.annotation_id)
.first() .first()
) )
return account return account
@ -1404,7 +1402,7 @@ class AppAnnotationSetting(Base):
collection_binding_detail = ( collection_binding_detail = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == self.collection_binding_id) .where(DatasetCollectionBinding.id == self.collection_binding_id)
.first() .first()
) )
return collection_binding_detail return collection_binding_detail

@ -343,7 +343,7 @@ class Workflow(Base):
return ( return (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
.count() .count()
> 0 > 0
) )

@ -21,7 +21,7 @@ def clean_embedding_cache_task():
try: try:
embedding_ids = ( embedding_ids = (
db.session.query(Embedding.id) db.session.query(Embedding.id)
.filter(Embedding.created_at < thirty_days_ago) .where(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc()) .order_by(Embedding.created_at.desc())
.limit(100) .limit(100)
.all() .all()

@ -36,7 +36,7 @@ def clean_messages():
# Main query with join and filter # Main query with join and filter
messages = ( messages = (
db.session.query(Message) db.session.query(Message)
.filter(Message.created_at < plan_sandbox_clean_message_day) .where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(100) .limit(100)
.all() .all()

@ -27,7 +27,7 @@ def clean_unused_datasets_task():
# Subquery for counting new documents # Subquery for counting new documents
document_subquery_new = ( document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter( .where(
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -40,7 +40,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents # Subquery for counting old documents
document_subquery_old = ( document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter( .where(
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -55,7 +55,7 @@ def clean_unused_datasets_task():
select(Dataset) select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter( .where(
Dataset.created_at < plan_sandbox_clean_day, Dataset.created_at < plan_sandbox_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0,
@ -72,7 +72,7 @@ def clean_unused_datasets_task():
for dataset in datasets: for dataset in datasets:
dataset_query = ( dataset_query = (
db.session.query(DatasetQuery) db.session.query(DatasetQuery)
.filter(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) .where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
.all() .all()
) )
if not dataset_query or len(dataset_query) == 0: if not dataset_query or len(dataset_query) == 0:
@ -80,7 +80,7 @@ def clean_unused_datasets_task():
# add auto disable log # add auto disable log
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.dataset_id == dataset.id, Document.dataset_id == dataset.id,
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -111,7 +111,7 @@ def clean_unused_datasets_task():
# Subquery for counting new documents # Subquery for counting new documents
document_subquery_new = ( document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter( .where(
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -124,7 +124,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents # Subquery for counting old documents
document_subquery_old = ( document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter( .where(
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -139,7 +139,7 @@ def clean_unused_datasets_task():
select(Dataset) select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter( .where(
Dataset.created_at < plan_pro_clean_day, Dataset.created_at < plan_pro_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0,
@ -155,7 +155,7 @@ def clean_unused_datasets_task():
for dataset in datasets: for dataset in datasets:
dataset_query = ( dataset_query = (
db.session.query(DatasetQuery) db.session.query(DatasetQuery)
.filter(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) .where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
.all() .all()
) )
if not dataset_query or len(dataset_query) == 0: if not dataset_query or len(dataset_query) == 0:

@ -17,7 +17,7 @@ def update_tidb_serverless_status_task():
# check the number of idle tidb serverless # check the number of idle tidb serverless
tidb_serverless_list = ( tidb_serverless_list = (
db.session.query(TidbAuthBinding) db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all() .all()
) )
if len(tidb_serverless_list) == 0: if len(tidb_serverless_list) == 0:

@ -900,7 +900,7 @@ class TenantService:
return ( return (
db.session.query(Tenant) db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.all() .all()
) )
@ -929,7 +929,7 @@ class TenantService:
tenant_account_join = ( tenant_account_join = (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.filter( .where(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL, Tenant.status == TenantStatus.NORMAL,
@ -955,7 +955,7 @@ class TenantService:
db.session.query(Account, TenantAccountJoin.role) db.session.query(Account, TenantAccountJoin.role)
.select_from(Account) .select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(TenantAccountJoin.tenant_id == tenant.id) .where(TenantAccountJoin.tenant_id == tenant.id)
) )
# Initialize an empty list to store the updated accounts # Initialize an empty list to store the updated accounts
@ -974,8 +974,8 @@ class TenantService:
db.session.query(Account, TenantAccountJoin.role) db.session.query(Account, TenantAccountJoin.role)
.select_from(Account) .select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(TenantAccountJoin.tenant_id == tenant.id) .where(TenantAccountJoin.tenant_id == tenant.id)
.filter(TenantAccountJoin.role == "dataset_operator") .where(TenantAccountJoin.role == "dataset_operator")
) )
# Initialize an empty list to store the updated accounts # Initialize an empty list to store the updated accounts
@ -995,9 +995,7 @@ class TenantService:
return ( return (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
.filter( .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]))
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
)
.first() .first()
is not None is not None
) )
@ -1007,7 +1005,7 @@ class TenantService:
"""Get the role of the current account for a given tenant""" """Get the role of the current account for a given tenant"""
join = ( join = (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.first() .first()
) )
return TenantAccountRole(join.role) if join else None return TenantAccountRole(join.role) if join else None
@ -1274,7 +1272,7 @@ class RegisterService:
tenant = ( tenant = (
db.session.query(Tenant) db.session.query(Tenant)
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") .where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first() .first()
) )
@ -1284,7 +1282,7 @@ class RegisterService:
tenant_account = ( tenant_account = (
db.session.query(Account, TenantAccountJoin.role) db.session.query(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) .where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first() .first()
) )

@ -25,7 +25,7 @@ class AgentService:
conversation: Conversation | None = ( conversation: Conversation | None = (
db.session.query(Conversation) db.session.query(Conversation)
.filter( .where(
Conversation.id == conversation_id, Conversation.id == conversation_id,
Conversation.app_id == app_model.id, Conversation.app_id == app_model.id,
) )
@ -37,7 +37,7 @@ class AgentService:
message: Optional[Message] = ( message: Optional[Message] = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.id == message_id, Message.id == message_id,
Message.conversation_id == conversation_id, Message.conversation_id == conversation_id,
) )
@ -55,9 +55,7 @@ class AgentService:
db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first() db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first()
) )
else: else:
executor = ( executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first()
db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first()
)
if executor: if executor:
executor = executor.name executor = executor.name

@ -26,7 +26,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -61,9 +61,7 @@ class AppAnnotationService:
db.session.add(annotation) db.session.add(annotation)
db.session.commit() db.session.commit()
# if annotation reply is enabled , add annotation to index # if annotation reply is enabled , add annotation to index
annotation_setting = ( annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting: if annotation_setting:
add_annotation_to_index_task.delay( add_annotation_to_index_task.delay(
annotation.id, annotation.id,
@ -117,7 +115,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -126,8 +124,8 @@ class AppAnnotationService:
if keyword: if keyword:
stmt = ( stmt = (
select(MessageAnnotation) select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id) .where(MessageAnnotation.app_id == app_id)
.filter( .where(
or_( or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)), MessageAnnotation.question.ilike("%{}%".format(keyword)),
MessageAnnotation.content.ilike("%{}%".format(keyword)), MessageAnnotation.content.ilike("%{}%".format(keyword)),
@ -138,7 +136,7 @@ class AppAnnotationService:
else: else:
stmt = ( stmt = (
select(MessageAnnotation) select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id) .where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
) )
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -149,7 +147,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -157,7 +155,7 @@ class AppAnnotationService:
raise NotFound("App not found") raise NotFound("App not found")
annotations = ( annotations = (
db.session.query(MessageAnnotation) db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id) .where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc()) .order_by(MessageAnnotation.created_at.desc())
.all() .all()
) )
@ -168,7 +166,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -181,9 +179,7 @@ class AppAnnotationService:
db.session.add(annotation) db.session.add(annotation)
db.session.commit() db.session.commit()
# if annotation reply is enabled , add annotation to index # if annotation reply is enabled , add annotation to index
annotation_setting = ( annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting: if annotation_setting:
add_annotation_to_index_task.delay( add_annotation_to_index_task.delay(
annotation.id, annotation.id,
@ -199,7 +195,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -236,7 +232,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -252,7 +248,7 @@ class AppAnnotationService:
annotation_hit_histories = ( annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory) db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.annotation_id == annotation_id) .where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all() .all()
) )
if annotation_hit_histories: if annotation_hit_histories:
@ -275,7 +271,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -314,7 +310,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -328,7 +324,7 @@ class AppAnnotationService:
stmt = ( stmt = (
select(AppAnnotationHitHistory) select(AppAnnotationHitHistory)
.filter( .where(
AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id, AppAnnotationHitHistory.annotation_id == annotation_id,
) )
@ -384,16 +380,14 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
if not app: if not app:
raise NotFound("App not found") raise NotFound("App not found")
annotation_setting = ( annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting: if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail collection_binding_detail = annotation_setting.collection_binding_detail
return { return {
@ -412,7 +406,7 @@ class AppAnnotationService:
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
@ -421,7 +415,7 @@ class AppAnnotationService:
annotation_setting = ( annotation_setting = (
db.session.query(AppAnnotationSetting) db.session.query(AppAnnotationSetting)
.filter( .where(
AppAnnotationSetting.app_id == app_id, AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id, AppAnnotationSetting.id == annotation_setting_id,
) )

@ -73,7 +73,7 @@ class APIBasedExtensionService:
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id) .filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name) .filter_by(name=extension_data.name)
.filter(APIBasedExtension.id != extension_data.id) .where(APIBasedExtension.id != extension_data.id)
.first() .first()
) )

@ -11,7 +11,7 @@ class ApiKeyAuthService:
def get_provider_auth_list(tenant_id: str) -> list: def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = ( data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding) db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all() .all()
) )
return data_source_api_key_bindings return data_source_api_key_bindings
@ -36,7 +36,7 @@ class ApiKeyAuthService:
def get_auth_credentials(tenant_id: str, category: str, provider: str): def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = ( data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding) db.session.query(DataSourceApiKeyAuthBinding)
.filter( .where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category, DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider, DataSourceApiKeyAuthBinding.provider == provider,
@ -53,7 +53,7 @@ class ApiKeyAuthService:
def delete_provider_auth(tenant_id: str, binding_id: str): def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = ( data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding) db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.first() .first()
) )
if data_source_api_key_binding: if data_source_api_key_binding:

@ -75,7 +75,7 @@ class BillingService:
join: Optional[TenantAccountJoin] = ( join: Optional[TenantAccountJoin] = (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first() .first()
) )

@ -30,7 +30,7 @@ class ClearFreePlanTenantExpiredLogs:
with Session(db.engine).no_autoflush as session: with Session(db.engine).no_autoflush as session:
messages = ( messages = (
session.query(Message) session.query(Message)
.filter( .where(
Message.app_id.in_(app_ids), Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
) )
@ -70,7 +70,7 @@ class ClearFreePlanTenantExpiredLogs:
with Session(db.engine).no_autoflush as session: with Session(db.engine).no_autoflush as session:
conversations = ( conversations = (
session.query(Conversation) session.query(Conversation)
.filter( .where(
Conversation.app_id.in_(app_ids), Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
) )
@ -276,7 +276,7 @@ class ClearFreePlanTenantExpiredLogs:
for test_interval in test_intervals: for test_interval in test_intervals:
tenant_count = ( tenant_count = (
session.query(Tenant.id) session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, current_time + test_interval)) .where(Tenant.created_at.between(current_time, current_time + test_interval))
.count() .count()
) )
if tenant_count <= 100: if tenant_count <= 100:
@ -301,7 +301,7 @@ class ClearFreePlanTenantExpiredLogs:
rs = ( rs = (
session.query(Tenant.id) session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, batch_end)) .where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at) .order_by(Tenant.created_at)
) )

@ -123,7 +123,7 @@ class ConversationService:
# get conversation first message # get conversation first message
message = ( message = (
db.session.query(Message) db.session.query(Message)
.filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) .where(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
.order_by(Message.created_at.asc()) .order_by(Message.created_at.asc())
.first() .first()
) )
@ -148,7 +148,7 @@ class ConversationService:
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter( .where(
Conversation.id == conversation_id, Conversation.id == conversation_id,
Conversation.app_id == app_model.id, Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),

@ -92,14 +92,14 @@ class DatasetService:
if user.current_role == TenantAccountRole.DATASET_OPERATOR: if user.current_role == TenantAccountRole.DATASET_OPERATOR:
# only show datasets that the user has permission to access # only show datasets that the user has permission to access
if permitted_dataset_ids: if permitted_dataset_ids:
query = query.filter(Dataset.id.in_(permitted_dataset_ids)) query = query.where(Dataset.id.in_(permitted_dataset_ids))
else: else:
return [], 0 return [], 0
else: else:
if user.current_role != TenantAccountRole.OWNER or not include_all: if user.current_role != TenantAccountRole.OWNER or not include_all:
# show all datasets that the user has permission to access # show all datasets that the user has permission to access
if permitted_dataset_ids: if permitted_dataset_ids:
query = query.filter( query = query.where(
db.or_( db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM, Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_( db.and_(
@ -112,7 +112,7 @@ class DatasetService:
) )
) )
else: else:
query = query.filter( query = query.where(
db.or_( db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM, Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_( db.and_(
@ -122,15 +122,15 @@ class DatasetService:
) )
else: else:
# if no user, only show datasets that are shared with all team members # if no user, only show datasets that are shared with all team members
query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search: if search:
query = query.filter(Dataset.name.ilike(f"%{search}%")) query = query.where(Dataset.name.ilike(f"%{search}%"))
if tag_ids: if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
if target_ids: if target_ids:
query = query.filter(Dataset.id.in_(target_ids)) query = query.where(Dataset.id.in_(target_ids))
else: else:
return [], 0 return [], 0
@ -143,7 +143,7 @@ class DatasetService:
# get the latest process rule # get the latest process rule
dataset_process_rule = ( dataset_process_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == dataset_id) .where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.limit(1) .limit(1)
.one_or_none() .one_or_none()
@ -697,7 +697,7 @@ class DatasetService:
def get_related_apps(dataset_id: str): def get_related_apps(dataset_id: str):
return ( return (
db.session.query(AppDatasetJoin) db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id) .where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at)) .order_by(db.desc(AppDatasetJoin.created_at))
.all() .all()
) )
@ -714,7 +714,7 @@ class DatasetService:
start_date = datetime.datetime.now() - datetime.timedelta(days=30) start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = ( dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog) db.session.query(DatasetAutoDisableLog)
.filter( .where(
DatasetAutoDisableLog.dataset_id == dataset_id, DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date, DatasetAutoDisableLog.created_at >= start_date,
) )
@ -859,7 +859,7 @@ class DocumentService:
def get_document_by_ids(document_ids: list[str]) -> list[Document]: def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.id.in_(document_ids), Document.id.in_(document_ids),
Document.enabled == True, Document.enabled == True,
Document.indexing_status == "completed", Document.indexing_status == "completed",
@ -873,7 +873,7 @@ class DocumentService:
def get_document_by_dataset_id(dataset_id: str) -> list[Document]: def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.dataset_id == dataset_id, Document.dataset_id == dataset_id,
Document.enabled == True, Document.enabled == True,
) )
@ -886,7 +886,7 @@ class DocumentService:
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.dataset_id == dataset_id, Document.dataset_id == dataset_id,
Document.enabled == True, Document.enabled == True,
Document.indexing_status == "completed", Document.indexing_status == "completed",
@ -901,7 +901,7 @@ class DocumentService:
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all() .all()
) )
return documents return documents
@ -910,7 +910,7 @@ class DocumentService:
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.batch == batch, Document.batch == batch,
Document.dataset_id == dataset_id, Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id, Document.tenant_id == current_user.current_tenant_id,
@ -1189,7 +1189,7 @@ class DocumentService:
for file_id in upload_file_list: for file_id in upload_file_list:
file = ( file = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -1270,7 +1270,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id workspace_id = notion_info.workspace_id
data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .where(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
@ -1413,7 +1413,7 @@ class DocumentService:
def get_tenant_documents_count(): def get_tenant_documents_count():
documents_count = ( documents_count = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.completed_at.isnot(None), Document.completed_at.isnot(None),
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -1469,7 +1469,7 @@ class DocumentService:
for file_id in upload_file_list: for file_id in upload_file_list:
file = ( file = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -1489,7 +1489,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id workspace_id = notion_info.workspace_id
data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .where(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
@ -2005,7 +2005,7 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=600): with redis_client.lock(lock_name, timeout=600):
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == document.id) .where(DocumentSegment.document_id == document.id)
.scalar() .scalar()
) )
segment_document = DocumentSegment( segment_document = DocumentSegment(
@ -2062,7 +2062,7 @@ class SegmentService:
) )
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == document.id) .where(DocumentSegment.document_id == document.id)
.scalar() .scalar()
) )
pre_segment_data_list = [] pre_segment_data_list = []
@ -2201,7 +2201,7 @@ class SegmentService:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id) .where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -2276,7 +2276,7 @@ class SegmentService:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id) .where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -2321,7 +2321,7 @@ class SegmentService:
index_node_ids = ( index_node_ids = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id) .with_entities(DocumentSegment.index_node_id)
.filter( .where(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id, DocumentSegment.document_id == document.id,
@ -2340,7 +2340,7 @@ class SegmentService:
if action == "enable": if action == "enable":
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id, DocumentSegment.document_id == document.id,
@ -2367,7 +2367,7 @@ class SegmentService:
elif action == "disable": elif action == "disable":
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id, DocumentSegment.document_id == document.id,
@ -2404,7 +2404,7 @@ class SegmentService:
index_node_hash = helper.generate_text_hash(content) index_node_hash = helper.generate_text_hash(content)
child_chunk_count = ( child_chunk_count = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter( .where(
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id, ChildChunk.document_id == document.id,
@ -2414,7 +2414,7 @@ class SegmentService:
) )
max_position = ( max_position = (
db.session.query(func.max(ChildChunk.position)) db.session.query(func.max(ChildChunk.position))
.filter( .where(
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id, ChildChunk.document_id == document.id,
@ -2457,7 +2457,7 @@ class SegmentService:
) -> list[ChildChunk]: ) -> list[ChildChunk]:
child_chunks = ( child_chunks = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter( .where(
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id, ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
@ -2578,7 +2578,7 @@ class SegmentService:
"""Get a child chunk by its ID.""" """Get a child chunk by its ID."""
result = ( result = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) .where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first() .first()
) )
return result if isinstance(result, ChildChunk) else None return result if isinstance(result, ChildChunk) else None
@ -2599,10 +2599,10 @@ class SegmentService:
) )
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.where(DocumentSegment.status.in_(status_list))
if keyword: if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc()) query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -2647,7 +2647,7 @@ class SegmentService:
# check segment # check segment
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first() .first()
) )
if not segment: if not segment:
@ -2664,7 +2664,7 @@ class SegmentService:
"""Get a segment by its ID.""" """Get a segment by its ID."""
result = ( result = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first() .first()
) )
return result if isinstance(result, DocumentSegment) else None return result if isinstance(result, DocumentSegment) else None
@ -2677,7 +2677,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding: ) -> DatasetCollectionBinding:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter( .where(
DatasetCollectionBinding.provider_name == provider_name, DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name, DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type, DatasetCollectionBinding.type == collection_type,
@ -2703,7 +2703,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding: ) -> DatasetCollectionBinding:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter( .where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
) )
.order_by(DatasetCollectionBinding.created_at) .order_by(DatasetCollectionBinding.created_at)
@ -2722,7 +2722,7 @@ class DatasetPermissionService:
db.session.query( db.session.query(
DatasetPermission.account_id, DatasetPermission.account_id,
) )
.filter(DatasetPermission.dataset_id == dataset_id) .where(DatasetPermission.dataset_id == dataset_id)
.all() .all()
) )

@ -30,11 +30,11 @@ class ExternalDatasetService:
) -> tuple[list[ExternalKnowledgeApis], int | None]: ) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = ( query = (
select(ExternalKnowledgeApis) select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id) .where(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc()) .order_by(ExternalKnowledgeApis.created_at.desc())
) )
if search: if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = db.paginate( external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False select=query, page=page, per_page=per_page, max_per_page=100, error_out=False

@ -50,7 +50,7 @@ class MessageService:
if first_id: if first_id:
first_message = ( first_message = (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == first_id) .where(Message.conversation_id == conversation.id, Message.id == first_id)
.first() .first()
) )
@ -59,7 +59,7 @@ class MessageService:
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at, Message.created_at < first_message.created_at,
Message.id != first_message.id, Message.id != first_message.id,
@ -71,7 +71,7 @@ class MessageService:
else: else:
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == conversation.id) .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(fetch_limit) .limit(fetch_limit)
.all() .all()
@ -109,19 +109,19 @@ class MessageService:
app_model=app_model, user=user, conversation_id=conversation_id app_model=app_model, user=user, conversation_id=conversation_id
) )
base_query = base_query.filter(Message.conversation_id == conversation.id) base_query = base_query.where(Message.conversation_id == conversation.id)
if include_ids is not None: if include_ids is not None:
base_query = base_query.filter(Message.id.in_(include_ids)) base_query = base_query.where(Message.id.in_(include_ids))
if last_id: if last_id:
last_message = base_query.filter(Message.id == last_id).first() last_message = base_query.where(Message.id == last_id).first()
if not last_message: if not last_message:
raise LastMessageNotExistsError() raise LastMessageNotExistsError()
history_messages = ( history_messages = (
base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(fetch_limit) .limit(fetch_limit)
.all() .all()
@ -183,7 +183,7 @@ class MessageService:
offset = (page - 1) * limit offset = (page - 1) * limit
feedbacks = ( feedbacks = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter(MessageFeedback.app_id == app_model.id) .where(MessageFeedback.app_id == app_model.id)
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc()) .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
.limit(limit) .limit(limit)
.offset(offset) .offset(offset)
@ -196,7 +196,7 @@ class MessageService:
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
message = ( message = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.id == message_id, Message.id == message_id,
Message.app_id == app_model.id, Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"), Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
@ -248,9 +248,7 @@ class MessageService:
if not conversation.override_model_configs: if not conversation.override_model_configs:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig) db.session.query(AppModelConfig)
.filter( .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
.first() .first()
) )
else: else:

@ -103,7 +103,7 @@ class ModelLoadBalancingService:
# Get load balancing configurations # Get load balancing configurations
load_balancing_configs = ( load_balancing_configs = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@ -219,7 +219,7 @@ class ModelLoadBalancingService:
# Get load balancing configurations # Get load balancing configurations
load_balancing_model_config = ( load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@ -307,7 +307,7 @@ class ModelLoadBalancingService:
current_load_balancing_configs = ( current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@ -457,7 +457,7 @@ class ModelLoadBalancingService:
# Get load balancing config # Get load balancing config
load_balancing_model_config = ( load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider, LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),

@ -17,7 +17,7 @@ class OpsService:
""" """
trace_config_data: Optional[TraceAppConfig] = ( trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig) db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first() .first()
) )
@ -148,7 +148,7 @@ class OpsService:
# check if trace config already exists # check if trace config already exists
trace_config_data: Optional[TraceAppConfig] = ( trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig) db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first() .first()
) )
@ -190,7 +190,7 @@ class OpsService:
# check if trace config already exists # check if trace config already exists
current_trace_config = ( current_trace_config = (
db.session.query(TraceAppConfig) db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first() .first()
) )
@ -227,7 +227,7 @@ class OpsService:
""" """
trace_config = ( trace_config = (
db.session.query(TraceAppConfig) db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first() .first()
) )

@ -101,7 +101,7 @@ class PluginMigration:
for test_interval in test_intervals: for test_interval in test_intervals:
tenant_count = ( tenant_count = (
session.query(Tenant.id) session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, current_time + test_interval)) .where(Tenant.created_at.between(current_time, current_time + test_interval))
.count() .count()
) )
if tenant_count <= 100: if tenant_count <= 100:
@ -126,7 +126,7 @@ class PluginMigration:
rs = ( rs = (
session.query(Tenant.id) session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, batch_end)) .where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at) .order_by(Tenant.created_at)
) )

@ -51,7 +51,7 @@ class PluginParameterService:
with Session(db.engine) as session: with Session(db.engine) as session:
db_record = ( db_record = (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider, BuiltinToolProvider.provider == provider,
) )

@ -33,14 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
""" """
recommended_apps = ( recommended_apps = (
db.session.query(RecommendedApp) db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) .where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all() .all()
) )
if len(recommended_apps) == 0: if len(recommended_apps) == 0:
recommended_apps = ( recommended_apps = (
db.session.query(RecommendedApp) db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all() .all()
) )
@ -83,7 +83,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
# is in public recommended list # is in public recommended list
recommended_app = ( recommended_app = (
db.session.query(RecommendedApp) db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) .where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
.first() .first()
) )

@ -17,7 +17,7 @@ class SavedMessageService:
raise ValueError("User is required") raise ValueError("User is required")
saved_messages = ( saved_messages = (
db.session.query(SavedMessage) db.session.query(SavedMessage)
.filter( .where(
SavedMessage.app_id == app_model.id, SavedMessage.app_id == app_model.id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id, SavedMessage.created_by == user.id,
@ -37,7 +37,7 @@ class SavedMessageService:
return return
saved_message = ( saved_message = (
db.session.query(SavedMessage) db.session.query(SavedMessage)
.filter( .where(
SavedMessage.app_id == app_model.id, SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id, SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
@ -67,7 +67,7 @@ class SavedMessageService:
return return
saved_message = ( saved_message = (
db.session.query(SavedMessage) db.session.query(SavedMessage)
.filter( .where(
SavedMessage.app_id == app_model.id, SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id, SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),

@ -16,10 +16,10 @@ class TagService:
query = ( query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
.filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
) )
if keyword: if keyword:
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all() results: list = query.order_by(Tag.created_at.desc()).all()
return results return results
@ -28,7 +28,7 @@ class TagService:
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all() .all()
) )
if not tags: if not tags:
@ -36,7 +36,7 @@ class TagService:
tag_ids = [tag.id for tag in tags] tag_ids = [tag.id for tag in tags]
tag_bindings = ( tag_bindings = (
db.session.query(TagBinding.target_id) db.session.query(TagBinding.target_id)
.filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all() .all()
) )
if not tag_bindings: if not tag_bindings:
@ -50,7 +50,7 @@ class TagService:
return [] return []
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all() .all()
) )
if not tags: if not tags:
@ -62,7 +62,7 @@ class TagService:
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id) .join(TagBinding, Tag.id == TagBinding.tag_id)
.filter( .where(
TagBinding.target_id == target_id, TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id, TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id, Tag.tenant_id == current_tenant_id,
@ -125,7 +125,7 @@ class TagService:
for tag_id in args["tag_ids"]: for tag_id in args["tag_ids"]:
tag_binding = ( tag_binding = (
db.session.query(TagBinding) db.session.query(TagBinding)
.filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.first() .first()
) )
if tag_binding: if tag_binding:
@ -146,7 +146,7 @@ class TagService:
# delete tag binding # delete tag binding
tag_bindings = ( tag_bindings = (
db.session.query(TagBinding) db.session.query(TagBinding)
.filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
.first() .first()
) )
if tag_bindings: if tag_bindings:
@ -158,7 +158,7 @@ class TagService:
if type == "knowledge": if type == "knowledge":
dataset = ( dataset = (
db.session.query(Dataset) db.session.query(Dataset)
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) .where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.first() .first()
) )
if not dataset: if not dataset:
@ -166,7 +166,7 @@ class TagService:
elif type == "app": elif type == "app":
app = ( app = (
db.session.query(App) db.session.query(App)
.filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) .where(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
.first() .first()
) )
if not app: if not app:

@ -119,7 +119,7 @@ class ApiToolManageService:
# check if the provider exists # check if the provider exists
provider = ( provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )
@ -210,7 +210,7 @@ class ApiToolManageService:
""" """
provider: ApiToolProvider | None = ( provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )
@ -257,7 +257,7 @@ class ApiToolManageService:
# check if the provider exists # check if the provider exists
provider = ( provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider, ApiToolProvider.name == original_provider,
) )
@ -326,7 +326,7 @@ class ApiToolManageService:
""" """
provider = ( provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )
@ -376,7 +376,7 @@ class ApiToolManageService:
db_provider = ( db_provider = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )

@ -154,7 +154,7 @@ class BuiltinToolManageService:
# get if the provider exists # get if the provider exists
db_provider = ( db_provider = (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id, BuiltinToolProvider.id == credential_id,
) )
@ -404,7 +404,7 @@ class BuiltinToolManageService:
with Session(db.engine) as session: with Session(db.engine) as session:
db_provider = ( db_provider = (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id, BuiltinToolProvider.id == credential_id,
) )
@ -613,7 +613,7 @@ class BuiltinToolManageService:
if provider_id_entity.organization != "langgenius": if provider_id_entity.organization != "langgenius":
provider = ( provider = (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name, BuiltinToolProvider.provider == full_provider_name,
) )
@ -626,7 +626,7 @@ class BuiltinToolManageService:
else: else:
provider = ( provider = (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name) (BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name), | (BuiltinToolProvider.provider == full_provider_name),
@ -647,7 +647,7 @@ class BuiltinToolManageService:
# it's an old provider without organization # it's an old provider without organization
return ( return (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by( .order_by(
BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first BuiltinToolProvider.created_at.asc(), # oldest first

@ -31,7 +31,7 @@ class MCPToolManageService:
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = ( res = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first() .first()
) )
if not res: if not res:
@ -42,7 +42,7 @@ class MCPToolManageService:
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider: def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = ( res = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier) .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first() .first()
) )
if not res: if not res:
@ -63,7 +63,7 @@ class MCPToolManageService:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = ( existing_provider = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter( .where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.tenant_id == tenant_id,
or_( or_(
MCPToolProvider.name == name, MCPToolProvider.name == name,
@ -100,7 +100,7 @@ class MCPToolManageService:
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
mcp_providers = ( mcp_providers = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id) .where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name) .order_by(MCPToolProvider.name)
.all() .all()
) )

@ -43,7 +43,7 @@ class WorkflowToolManageService:
# check if the name is unique # check if the name is unique
existing_workflow_tool_provider = ( existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter( .where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id # name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
@ -123,7 +123,7 @@ class WorkflowToolManageService:
# check if the name is unique # check if the name is unique
existing_workflow_tool_provider = ( existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter( .where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name, WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id, WorkflowToolProvider.id != workflow_tool_id,
@ -136,7 +136,7 @@ class WorkflowToolManageService:
workflow_tool_provider: WorkflowToolProvider | None = ( workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first() .first()
) )
@ -243,7 +243,7 @@ class WorkflowToolManageService:
""" """
db_tool: WorkflowToolProvider | None = ( db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first() .first()
) )
return cls._get_workflow_tool(tenant_id, db_tool) return cls._get_workflow_tool(tenant_id, db_tool)
@ -259,7 +259,7 @@ class WorkflowToolManageService:
""" """
db_tool: WorkflowToolProvider | None = ( db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first() .first()
) )
return cls._get_workflow_tool(tenant_id, db_tool) return cls._get_workflow_tool(tenant_id, db_tool)
@ -318,7 +318,7 @@ class WorkflowToolManageService:
""" """
db_tool: WorkflowToolProvider | None = ( db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first() .first()
) )

@ -36,7 +36,7 @@ class VectorService:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:

@ -65,7 +65,7 @@ class WebConversationService:
return return
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db.session.query(PinnedConversation)
.filter( .where(
PinnedConversation.app_id == app_model.id, PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id, PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
@ -97,7 +97,7 @@ class WebConversationService:
return return
pinned_conversation = ( pinned_conversation = (
db.session.query(PinnedConversation) db.session.query(PinnedConversation)
.filter( .where(
PinnedConversation.app_id == app_model.id, PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id, PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),

@ -620,7 +620,7 @@ class WorkflowConverter:
""" """
api_based_extension = ( api_based_extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )

@ -328,7 +328,7 @@ class WorkflowDraftVariableService:
def delete_workflow_variables(self, app_id: str): def delete_workflow_variables(self, app_id: str):
( (
self._session.query(WorkflowDraftVariable) self._session.query(WorkflowDraftVariable)
.filter(WorkflowDraftVariable.app_id == app_id) .where(WorkflowDraftVariable.app_id == app_id)
.delete(synchronize_session=False) .delete(synchronize_session=False)
) )
@ -379,7 +379,7 @@ class WorkflowDraftVariableService:
if conv_id is not None: if conv_id is not None:
conversation = ( conversation = (
self._session.query(Conversation) self._session.query(Conversation)
.filter( .where(
Conversation.id == conv_id, Conversation.id == conv_id,
Conversation.app_id == workflow.app_id, Conversation.app_id == workflow.app_id,
) )

@ -89,7 +89,7 @@ class WorkflowService:
def is_workflow_exist(self, app_model: App) -> bool: def is_workflow_exist(self, app_model: App) -> bool:
return ( return (
db.session.query(Workflow) db.session.query(Workflow)
.filter( .where(
Workflow.tenant_id == app_model.tenant_id, Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id, Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT, Workflow.version == Workflow.VERSION_DRAFT,
@ -104,7 +104,7 @@ class WorkflowService:
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow = ( workflow = (
db.session.query(Workflow) db.session.query(Workflow)
.filter( .where(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
) )
.first() .first()
@ -117,7 +117,7 @@ class WorkflowService:
# fetch published workflow by workflow_id # fetch published workflow by workflow_id
workflow = ( workflow = (
db.session.query(Workflow) db.session.query(Workflow)
.filter( .where(
Workflow.tenant_id == app_model.tenant_id, Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id, Workflow.app_id == app_model.id,
Workflow.id == workflow_id, Workflow.id == workflow_id,
@ -141,7 +141,7 @@ class WorkflowService:
# fetch published workflow by workflow_id # fetch published workflow by workflow_id
workflow = ( workflow = (
db.session.query(Workflow) db.session.query(Workflow)
.filter( .where(
Workflow.tenant_id == app_model.tenant_id, Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id, Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id, Workflow.id == app_model.workflow_id,
@ -658,7 +658,7 @@ class WorkflowService:
# Check if there's a tool provider using this specific workflow version # Check if there's a tool provider using this specific workflow version
tool_provider = ( tool_provider = (
session.query(WorkflowToolProvider) session.query(WorkflowToolProvider)
.filter( .where(
WorkflowToolProvider.tenant_id == workflow.tenant_id, WorkflowToolProvider.tenant_id == workflow.tenant_id,
WorkflowToolProvider.app_id == workflow.app_id, WorkflowToolProvider.app_id == workflow.app_id,
WorkflowToolProvider.version == workflow.version, WorkflowToolProvider.version == workflow.version,

@ -25,7 +25,7 @@ class WorkspaceService:
# Get role of user # Get role of user
tenant_account_join = ( tenant_account_join = (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first() .first()
) )
assert tenant_account_join is not None, "TenantAccountJoin not found" assert tenant_account_join is not None, "TenantAccountJoin not found"

@ -43,7 +43,7 @@ def add_document_to_index_task(dataset_document_id: str):
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == False, DocumentSegment.enabled == False,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
@ -86,9 +86,7 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor.load(dataset, documents) index_processor.load(dataset, documents)
# delete auto disable log # delete auto disable log
db.session.query(DatasetAutoDisableLog).where( db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
# update segment to enable # update segment to enable
db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(

@ -26,9 +26,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
db.session.close() db.session.close()
return return
app_annotation_setting = ( app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if not app_annotation_setting: if not app_annotation_setting:
logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red")) logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red"))

@ -46,9 +46,7 @@ def enable_annotation_reply_task(
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation" embedding_provider_name, embedding_model_name, "annotation"
) )
annotation_setting = ( annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting: if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id: if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = ( old_dataset_collection_binding = (

@ -81,7 +81,7 @@ def batch_create_segment_to_index_task(
segment_hash = helper.generate_text_hash(content) # type: ignore segment_hash = helper.generate_text_hash(content) # type: ignore
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == dataset_document.id) .where(DocumentSegment.document_id == dataset_document.id)
.scalar() .scalar()
) )
segment_document = DocumentSegment( segment_document = DocumentSegment(

@ -102,7 +102,7 @@ def clean_dataset_task(
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file = ( file = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first() .first()
) )
if not file: if not file:

@ -35,7 +35,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
elif action == "add": elif action == "add":
dataset_documents = ( dataset_documents = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter( .where(
DatasetDocument.dataset_id == dataset_id, DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -56,7 +56,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
# add from vector index # add from vector index
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc()) .order_by(DocumentSegment.position.asc())
.all() .all()
) )
@ -88,7 +88,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
elif action == "update": elif action == "update":
dataset_documents = ( dataset_documents = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter( .where(
DatasetDocument.dataset_id == dataset_id, DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -113,7 +113,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
try: try:
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc()) .order_by(DocumentSegment.position.asc())
.all() .all()
) )

@ -44,7 +44,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id, DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id, DocumentSegment.document_id == document_id,

@ -46,7 +46,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_edited_time = data_source_info["last_edited_time"] page_edited_time = data_source_info["last_edited_time"]
data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .where(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id, DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",

@ -45,7 +45,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id, DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id, DocumentSegment.document_id == document_id,

@ -142,9 +142,9 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(annotation_hit_history_id: str): def del_annotation_hit_history(annotation_hit_history_id: str):
db.session.query(AppAnnotationHitHistory).where( db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
AppAnnotationHitHistory.id == annotation_hit_history_id synchronize_session=False
).delete(synchronize_session=False) )
_delete_records( _delete_records(
"""select id from app_annotation_hit_histories where app_id=:app_id limit 1000""", """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""",

Loading…
Cancel
Save