Merge branch 'feat/change-email-completed-notification' into deploy/dev

JzoNg 7 months ago
commit 613f39c96a

@ -0,0 +1,27 @@
name: autofix.ci
on:
workflow_call:
pull_request:
push:
branches: [ "main" ]
permissions:
contents: read
jobs:
autofix:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
- run: |
cd api
uv sync --dev
# Fix lint errors
uv run ruff check --fix-only .
# Format code
uv run ruff format .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

@ -50,7 +50,7 @@ def reset_password(email, new_password, password_confirm):
click.echo(click.style("Passwords do not match.", fg="red"))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
account = db.session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
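The same one-line substitution recurs throughout this commit: on SQLAlchemy 1.4+, Query.where() is a synonym for Query.filter(), added so legacy Query code reads like the 2.0-style select() API. A minimal sketch of the equivalence, reusing the Account/db names from the hunk above:

from sqlalchemy import select

# legacy spelling
account = db.session.query(Account).filter(Account.email == email).one_or_none()

# spelling used after this commit; Query.where() behaves identically to Query.filter()
account = db.session.query(Account).where(Account.email == email).one_or_none()

# fully 2.0-style equivalent, for comparison
account = db.session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()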
@ -89,7 +89,7 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style("New emails do not match.", fg="red"))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
account = db.session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@ -136,8 +136,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(
@ -172,7 +172,7 @@ def migrate_annotation_vector_database():
per_page = 50
apps = (
db.session.query(App)
.filter(App.status == "normal")
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
@ -192,7 +192,7 @@ def migrate_annotation_vector_database():
try:
click.echo("Creating app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
@ -202,13 +202,13 @@ def migrate_annotation_vector_database():
# get dataset_collection_binding info
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo("App annotation collection binding not found: {}".format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
@ -305,7 +305,7 @@ def migrate_knowledge_vector_database():
while True:
try:
stmt = (
select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -332,7 +332,7 @@ def migrate_knowledge_vector_database():
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
@ -367,7 +367,7 @@ def migrate_knowledge_vector_database():
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@ -381,7 +381,7 @@ def migrate_knowledge_vector_database():
for dataset_document in dataset_documents:
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
@ -468,7 +468,7 @@ def convert_to_agent_apps():
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if app is not None:
apps.append(app)
@ -483,7 +483,7 @@ def convert_to_agent_apps():
db.session.commit()
# update conversation mode to agent
db.session.query(Conversation).filter(Conversation.app_id == app.id).update(
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value}
)
@ -560,7 +560,7 @@ def old_metadata_migration():
try:
stmt = (
select(DatasetDocument)
.filter(DatasetDocument.doc_metadata.is_not(None))
.where(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc())
)
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -578,7 +578,7 @@ def old_metadata_migration():
else:
dataset_metadata = (
db.session.query(DatasetMetadata)
.filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first()
)
if not dataset_metadata:
@ -602,7 +602,7 @@ def old_metadata_migration():
else:
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
.filter(
.where(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
@ -717,7 +717,7 @@ where sites.id is null limit 1000"""
continue
try:
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
print(f"App {app_id} not found")
continue

@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource):
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app:
@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id):
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).filter(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)

@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource):
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
return {"items": keys}
@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource):
current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
)
@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource):
key = (
db.session.query(ApiToken)
.filter(
.where(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204

@ -49,7 +49,7 @@ class CompletionConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).filter(
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])),
@ -121,7 +121,7 @@ class CompletionConversationDetailApi(Resource):
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
@ -181,7 +181,7 @@ class ChatConversationApi(Resource):
Message.conversation_id == Conversation.id,
)
.join(subquery, subquery.c.conversation_id == Conversation.id)
.filter(
.where(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
@ -286,7 +286,7 @@ class ChatConversationDetailApi(Resource):
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
@ -308,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversati
def _get_conversation(app_model, conversation_id):
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)

@ -26,7 +26,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server
@setup_required
@ -73,7 +73,7 @@ class AppMCPServerController(Resource):
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server:
raise NotFound()
@ -104,8 +104,8 @@ class AppMCPServerRefreshController(Resource):
raise NotFound()
server = (
db.session.query(AppMCPServer)
.filter(AppMCPServer.id == server_id)
.filter(AppMCPServer.tenant_id == current_user.current_tenant_id)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.first()
)
if not server:

@ -56,7 +56,7 @@ class ChatMessageListApi(Resource):
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)
@ -66,7 +66,7 @@ class ChatMessageListApi(Resource):
if args["first_id"]:
first_message = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
@ -75,7 +75,7 @@ class ChatMessageListApi(Resource):
history_messages = (
db.session.query(Message)
.filter(
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
@ -87,7 +87,7 @@ class ChatMessageListApi(Resource):
else:
history_messages = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
@ -98,7 +98,7 @@ class ChatMessageListApi(Resource):
current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message)
.filter(
.where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
@ -167,7 +167,7 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required
@get_app_model
def get(self, app_model):
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
return {"count": count}
@ -214,7 +214,7 @@ class MessageApi(Resource):
def get(self, app_model, message_id):
message_id = str(message_id)
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")

@ -42,7 +42,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
)
if original_app_model_config is None:
raise ValueError("Original app model config not found")

@ -49,7 +49,7 @@ class AppSite(Resource):
if not current_user.is_editor:
raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound
@ -93,7 +93,7 @@ class AppSiteAccessTokenReset(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound

@ -11,7 +11,7 @@ from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]:
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model

@ -30,7 +30,7 @@ class DataSourceApi(Resource):
# get workspace data source integrates
data_source_integrates = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource):
page_id = str(page_id)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter(
select(DataSourceOauthBinding).where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",

@ -412,7 +412,7 @@ class DatasetIndexingEstimateApi(Resource):
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all()
)
@ -517,14 +517,14 @@ class DatasetIndexingStatusApi(Resource):
dataset_id = str(dataset_id)
documents = (
db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all()
)
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@ -533,7 +533,7 @@ class DatasetIndexingStatusApi(Resource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
@ -568,7 +568,7 @@ class DatasetApiKeyApi(Resource):
def get(self):
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all()
)
return {"items": keys}
@ -584,7 +584,7 @@ class DatasetApiKeyApi(Resource):
current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count()
)
@ -620,7 +620,7 @@ class DatasetApiDeleteApi(Resource):
key = (
db.session.query(ApiToken)
.filter(
.where(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
@ -631,7 +631,7 @@ class DatasetApiDeleteApi(Resource):
if key is None:
flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204

@ -124,7 +124,7 @@ class GetProcessRuleApi(Resource):
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == document.dataset_id)
.where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
@ -176,7 +176,7 @@ class DatasetDocumentListApi(Resource):
if search:
search = f"%{search}%"
query = query.filter(Document.name.like(search))
query = query.where(Document.name.like(search))
if sort.startswith("-"):
sort_logic = desc
@ -212,7 +212,7 @@ class DatasetDocumentListApi(Resource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@ -221,7 +221,7 @@ class DatasetDocumentListApi(Resource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
@ -417,7 +417,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
@ -492,7 +492,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.first()
)
@ -568,7 +568,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@ -577,7 +577,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
@ -611,7 +611,7 @@ class DocumentIndexingStatusApi(DocumentResource):
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
@ -620,7 +620,7 @@ class DocumentIndexingStatusApi(DocumentResource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)

@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = (
select(DocumentSegment)
.filter(
.where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource):
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
query = query.where(DocumentSegment.status.in_(status_list))
if hit_count_gte is not None:
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
query = query.filter(DocumentSegment.enabled == True)
query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)
query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
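As in the hunk above, conditions added with successive .where() calls are ANDed onto the statement, exactly as successive .filter() calls were, so filters can still be composed incrementally. A brief sketch (model and variable names taken from the hunk; the plain-session fetch at the end is an assumption standing in for db.paginate):

from sqlalchemy import select

stmt = select(DocumentSegment).where(
    DocumentSegment.document_id == str(document_id),
    DocumentSegment.tenant_id == current_user.current_tenant_id,
)
if status_list:
    # each additional .where() is ANDed onto the statement
    stmt = stmt.where(DocumentSegment.status.in_(status_list))
if keyword:
    stmt = stmt.where(DocumentSegment.content.ilike(f"%{keyword}%"))

# assumption: fetch one page manually instead of Flask-SQLAlchemy's db.paginate
segments = db.session.scalars(stmt.limit(limit).offset((page - 1) * limit)).all()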
@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:

@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource):
if app_id:
installed_apps = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
else:
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [
@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None:
raise NotFound("App not found")
current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).filter(App.id == args["app_id"]).first()
app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None:
raise NotFound("App not found")
@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource):
installed_app = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first()
)

@ -28,7 +28,7 @@ def installed_app_required(view=None):
installed_app = (
db.session.query(InstalledApp)
.filter(
.where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first()

@ -21,7 +21,7 @@ def plugin_permission_required(
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.filter(
.where(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()

@ -68,7 +68,7 @@ class AccountInitApi(Resource):
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
.filter(
.where(
InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused",
)
@ -228,7 +228,7 @@ class AccountIntegrateApi(Resource):
def get(self):
account = current_user
account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login"
@ -494,6 +494,10 @@ class ChangeEmailResetApi(Resource):
updated_account = AccountService.update_account(current_user, email=args["new_email"])
AccountService.send_change_email_completed_notify_email(
email=args["new_email"],
)
return updated_account

@ -108,7 +108,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None:
abort(404)
else:

@ -22,7 +22,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
user_id = "DEFAULT-USER"
if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first()
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
else:
user_model = AccountService.load_user(user_id)
if not user_model:
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
user_model = session.query(EndUser).where(EndUser.id == user_id).first()
if not user_model:
raise ValueError("user not found")
except Exception:
@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None):
try:
tenant_model = (
db.session.query(Tenant)
.filter(
.where(
Tenant.id == tenant_id,
)
.first()

@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view):
if signature_base64 != token:
return view(*args, **kwargs)
kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
return view(*args, **kwargs)

@ -30,7 +30,7 @@ class MCPAppApi(Resource):
request_id = args.get("id")
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
@ -41,7 +41,7 @@ class MCPAppApi(Resource):
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)
app = db.session.query(App).filter(App.id == server.app_id).first()
app = db.session.query(App).where(App.id == server.app_id).first()
if not app:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")

@ -16,7 +16,7 @@ class AppSiteApi(Resource):
@marshal_with(fields.site_fields)
def get(self, app_model: App):
"""Retrieve app site info."""
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()

@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource):
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource):
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource):
if search:
search = f"%{search}%"
query = query.filter(Document.name.like(search))
query = query.where(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position))
@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# get documents
@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields

@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource):
tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")

@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource):
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:

@ -44,7 +44,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
if not app_model:
raise Forbidden("The app no longer exists.")
@ -54,7 +54,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
if tenant is None:
raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE:
@ -62,15 +62,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
@ -213,15 +213,15 @@ def validate_dataset_token(view=None):
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
@ -293,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
end_user = (
db.session.query(EndUser)
.filter(
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
@ -320,7 +320,7 @@ class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
if not dataset:
raise NotFound("Dataset not found.")

@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
from flask import request
from flask_restful import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@ -42,17 +43,17 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if end_user:
@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None
if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if session_id:
end_user = (
db.session.query(EndUser)
.filter(
end_user = db.session.scalar(
select(EndUser).where(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.first()
)
if not end_user:
if not session_id:
@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if not end_user:
@ -224,6 +223,8 @@ def generate_session_id():
"""
while True:
session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
existing_count = db.session.scalar(
select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
)
if existing_count == 0:
return session_id
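The passport changes above go a step beyond the filter/where rename: the legacy Query chain is replaced with db.session.scalar(select(...)), which executes the statement and returns the first scalar result or None. A condensed side-by-side sketch, using the Site/EndUser names from these hunks:

from sqlalchemy import func, select

# before: legacy Query API
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()

# after: 2.0-style select(), one scalar result (or None)
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))

# counts follow the same shape
existing_count = db.session.scalar(
    select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
)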

@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user):
"""Retrieve app site info."""
# get site
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()

@ -3,6 +3,7 @@ from functools import wraps
from flask import request
from flask_restful import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -48,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.query(Site).filter(Site.code == app_code).first()
app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
@ -57,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()

@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner):
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.query(MessageAgentThought)
.filter(
.where(
MessageAgentThought.message_id == self.message.id,
)
.count()
@ -336,7 +336,7 @@ class BaseAgentRunner(AppRunner):
Save agent thought
"""
updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
)
if not updated_agent_thought:
raise ValueError("agent thought not found")
@ -496,7 +496,7 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:

@ -72,7 +72,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")

@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).filter(Message.id == message.id).first()
message_result = db.session.query(Message).where(Message.id == message.id).first()
if message_result is None:
raise ValueError("Message not found")
db.session.close()

@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")

@ -248,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
"""
message = (
db.session.query(Message)
.filter(
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),

@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")

@ -85,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)
@ -151,13 +151,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
introduction = self._get_conversation_introduction(application_generate_entity)
# get conversation name
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
query = application_generate_entity.query or "New conversation"
else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
if isinstance(query, int):
query = str(query)
query = query or "New conversation"
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "…") if len(query) > 20 else query
if not conversation:
@ -259,7 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@ -272,7 +266,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).filter(Message.id == message_id).first()
message = db.session.query(Message).where(Message.id == message_id).first()
if message is None:
raise MessageNotExistsError("Message not exists")

@ -26,7 +26,7 @@ class AnnotationReplyFeature:
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
if not annotation_setting:

@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:

@ -81,7 +81,7 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation:
return
@ -140,7 +140,7 @@ class MessageCycleManager:
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
if message_file and message_file.url is not None:
# get tool file id

@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler:
if child_chunk:
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
query = db.session.query(DocumentSegment).filter(
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
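The hit-count bump above never loads the matching segments: Query.update() emits a single UPDATE ... SET hit_count = hit_count + 1, and synchronize_session=False skips reconciling objects already held in the session. A hedged sketch of the same pattern in isolation, with import paths assumed from the project layout:

    from extensions.ext_database import db
    from models.dataset import DocumentSegment

    def bump_hit_count(index_node_id: str, dataset_id: str | None = None) -> int:
        query = db.session.query(DocumentSegment).where(
            DocumentSegment.index_node_id == index_node_id
        )
        if dataset_id:
            query = query.where(DocumentSegment.dataset_id == dataset_id)
        # One UPDATE statement, no prior SELECT of the rows being changed.
        matched = query.update(
            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
            synchronize_session=False,
        )
        db.session.commit()
        return matched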

@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel):
provider_record = (
db.session.query(Provider)
.filter(
.where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel):
provider_model_record = (
db.session.query(ProviderModel)
.filter(
.where(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel):
return (
db.session.query(ProviderModelSetting)
.filter(
.where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel):
return (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel):
load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel):
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
.where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
.filter(
.where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
)

@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)

@ -15,7 +15,7 @@ def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()

@ -59,7 +59,7 @@ class IndexingRunner:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@ -119,12 +119,12 @@ class IndexingRunner:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@ -212,7 +212,7 @@ class IndexingRunner:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
@ -316,7 +316,7 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
@ -346,7 +346,7 @@ class IndexingRunner:
raise ValueError("no upload file found")
file_detail = (
db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
if file_detail:
@ -599,7 +599,7 @@ class IndexingRunner:
keyword.create(documents)
if dataset.indexing_technique != "high_quality":
document_ids = [document.metadata["doc_id"] for document in documents]
db.session.query(DocumentSegment).filter(
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
@ -630,7 +630,7 @@ class IndexingRunner:
index_processor.load(dataset, chunk_documents, with_keywords=False)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),

@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler:
):
self.app = app
self.request = request
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler:
def retrieve_end_user(self):
return (
db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)

@ -67,7 +67,7 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:

@ -89,7 +89,7 @@ class ApiModeration(Moderation):
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)

@ -120,7 +120,7 @@ class AliyunDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
@ -244,14 +244,14 @@ class AliyunDataTrace(BaseTraceInstance):
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
app = session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (

@ -297,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
# Add end user data if available
if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
@ -703,7 +703,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
return workflow_nodes

@ -44,14 +44,14 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).filter(App.id == app_id).first()
app = session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

@ -244,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id

@ -262,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id

@ -284,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id

@ -218,7 +218,7 @@ class OpsTraceManager:
"""
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@ -226,7 +226,7 @@ class OpsTraceManager:
return None
# decrypt_token
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("App not found")
@ -253,7 +253,7 @@ class OpsTraceManager:
if app_id is None:
return None
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if app is None:
return None
@ -293,18 +293,18 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).filter(Message.id == message_id).first()
message_data = db.session.query(Message).where(Message.id == message_id).first()
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation_data:
return None
if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation_data.app_model_config_id)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
@ -331,7 +331,7 @@ class OpsTraceManager:
if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app_config:
raise ValueError("App not found")
app_config.tracing = json.dumps(
@ -349,7 +349,7 @@ class OpsTraceManager:
:param app_id: app id
:return:
"""
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("App not found")
if not app.tracing:

@ -3,6 +3,8 @@ from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db
from models.model import Message
@ -20,7 +22,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first()
return db.session.scalar(select(Message).where(Message.id == message_id))
@contextmanager
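get_message_data also switches idioms: instead of a legacy Query, it builds a 2.0-style select() and hands it to db.session.scalar(), which returns the first result (the mapped Message) or None. Several later hunks apply the same select()/scalar() pattern. A small sketch of the equivalence, imports assumed:

    from sqlalchemy import select

    from extensions.ext_database import db
    from models.model import Message

    def get_message_data(message_id: str):
        # Equivalent to: db.session.query(Message).filter(Message.id == message_id).first()
        return db.session.scalar(select(Message).where(Message.id == message_id))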

@ -235,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id

@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get the user by user id
"""
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
if not user:
user = db.session.query(Account).filter(Account.id == user_id).first()
user = db.session.query(Account).where(Account.id == user_id).first()
if not user:
raise ValueError("user not found")
@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first()
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
except Exception:
raise ValueError("app not found")

@ -275,7 +275,7 @@ class ProviderManager:
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
.filter(
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
@ -367,7 +367,7 @@ class ProviderManager:
# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
.filter(
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
@ -541,7 +541,7 @@ class ProviderManager:
db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
.filter(
.where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,

@ -93,11 +93,11 @@ class Jieba(BaseKeyword):
documents = []
for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).filter(
segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment:
@ -214,7 +214,7 @@ class Jieba(BaseKeyword):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
)
if document_segment:

@ -127,7 +127,7 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return []
metadata_condition = (
@ -145,7 +145,7 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
with Session(db.engine) as session:
return session.query(Dataset).filter(Dataset.id == dataset_id).first()
return session.query(Dataset).where(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
@ -294,7 +294,7 @@ class RetrievalService:
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
.filter(DatasetDocument.id.in_(document_ids))
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
}
@ -318,7 +318,7 @@ class RetrievalService:
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
if not child_chunk:
@ -326,7 +326,7 @@ class RetrievalService:
segment = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
@ -381,7 +381,7 @@ class RetrievalService:
segment = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",

@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:

@ -296,12 +296,22 @@ class TableStoreVector(BaseVector):
documents = []
for search_hit in search_response.search_hits:
if search_hit.score > score_threshold:
metadata = json.loads(search_hit.row[1][0][1])
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
metadata["score"] = search_hit.score
documents.append(
Document(
page_content=search_hit.row[1][2][1],
vector=json.loads(search_hit.row[1][3][1]),
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
vector=vector,
metadata=metadata,
)
)
@ -309,7 +319,7 @@ class TableStoreVector(BaseVector):
return documents
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
bool_query = tablestore.BoolQuery()
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
if document_ids_filter:
@ -329,11 +339,20 @@ class TableStoreVector(BaseVector):
documents = []
for search_hit in search_response.search_hits:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
documents.append(
Document(
page_content=search_hit.row[1][2][1],
vector=json.loads(search_hit.row[1][3][1]),
metadata=json.loads(search_hit.row[1][0][1]),
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
vector=vector,
metadata=metadata,
)
)
return documents
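The TableStore change is more than a respelling. The old code read vector, content, and metadata by fixed positions in search_hit.row[1], which breaks as soon as the attribute columns come back in a different order or one of them is missing; the new code first builds a name-to-value map and then looks fields up by the Field enum, falling back to None, an empty dict, or empty content. The explicit BoolQuery(must_queries=[], ...) constructor similarly guards against the query object starting out without those list attributes. A hedged sketch of the lookup, treating search_hit.row[1] as a list of (column_name, value, ...) tuples as the diff implies:

    import json

    def read_hit(search_hit, content_key: str, vector_key: str, metadata_key: str) -> dict:
        # Map attribute columns by name instead of relying on their position.
        columns = {col[0]: col[1] for col in search_hit.row[1]}
        vector_str = columns.get(vector_key)
        metadata_str = columns.get(metadata_key)
        return {
            "page_content": columns.get(content_key) or "",
            "vector": json.loads(vector_str) if vector_str else None,
            "metadata": json.loads(metadata_str) if metadata_str else {},
        }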

@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)

@ -47,7 +47,7 @@ class Vector:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
)
if whitelist:

@ -42,7 +42,7 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
output = {}
@ -63,7 +63,7 @@ class DatasetDocumentStore:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id)
.where(DocumentSegment.document_id == self._document_id)
.scalar()
)
@ -147,7 +147,7 @@ class DatasetDocumentStore:
segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).filter(
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
@ -230,7 +230,7 @@ class DatasetDocumentStore:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
)

@ -366,7 +366,7 @@ class NotionExtractor(BaseExtractor):
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",

@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
.where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
db.session.query(ChildChunk).filter(
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
@ -136,7 +136,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()
def retrieve(

@ -135,7 +135,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
@ -242,7 +242,7 @@ class DatasetRetrieval:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
@ -327,7 +327,7 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
results = []
if dataset.provider == "external":
@ -516,14 +516,14 @@ class DatasetRetrieval:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.filter(DatasetDocument.id == document.metadata["document_id"])
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
@ -533,7 +533,7 @@ class DatasetRetrieval:
if child_chunk:
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
@ -541,13 +541,13 @@ class DatasetRetrieval:
)
db.session.commit()
else:
query = db.session.query(DocumentSegment).filter(
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
@ -600,7 +600,7 @@ class DatasetRetrieval:
):
with flask_app.app_context():
with Session(db.engine) as session:
dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return []
@ -685,7 +685,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
@ -862,7 +862,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(DatasetDocument).filter(
document_query = db.session.query(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@ -930,9 +930,9 @@ class DatasetRetrieval:
raise ValueError("Invalid metadata filtering mode")
if filters:
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@ -958,7 +958,7 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:

@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)

@ -160,7 +160,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == id,
)
.first()
@ -184,7 +184,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.filter(
.where(
MessageFile.id == id,
)
.first()
@ -204,7 +204,7 @@ class ToolFileManager:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == tool_file_id,
)
.first()
@ -228,7 +228,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == tool_file_id,
)
.first()

@ -29,7 +29,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels
for label in labels:
@ -57,7 +57,7 @@ class ToolLabelManager:
labels = (
db.session.query(ToolLabelBinding.label_name)
.filter(
.where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
@ -90,7 +90,7 @@ class ToolLabelManager:
provider_ids.append(controller.provider_id)
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}

@ -198,7 +198,7 @@ class ToolManager:
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
@ -216,7 +216,7 @@ class ToolManager:
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
@ -229,7 +229,7 @@ class ToolManager:
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
@ -316,7 +316,7 @@ class ToolManager:
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
@ -616,7 +616,7 @@ class ToolManager:
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(
@ -664,7 +664,7 @@ class ToolManager:
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
)
api_provider_controllers: list[dict[str, Any]] = [
@ -687,7 +687,7 @@ class ToolManager:
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers: list[WorkflowToolProviderController] = []
@ -731,7 +731,7 @@ class ToolManager:
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
@ -768,7 +768,7 @@ class ToolManager:
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.filter(
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
@ -793,7 +793,7 @@ class ToolManager:
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
@ -885,7 +885,7 @@ class ToolManager:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
@ -902,7 +902,7 @@ class ToolManager:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
@ -919,7 +919,7 @@ class ToolManager:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)

@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.filter(
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
if not dataset:

@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
if not dataset:
@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
.filter(
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,

@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController):
"""
workflow: Workflow | None = (
db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
)
@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)

@ -142,12 +142,12 @@ class WorkflowTool(Tool):
if not version:
workflow = (
db.session.query(Workflow)
.filter(Workflow.app_id == app_id, Workflow.version != "draft")
.where(Workflow.app_id == app_id, Workflow.version != "draft")
.order_by(Workflow.created_at.desc())
.first()
)
else:
workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first()
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
if not workflow:
raise ValueError("workflow not found or not published")
@ -158,7 +158,7 @@ class WorkflowTool(Tool):
"""
get the app by app id
"""
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("app not found")

@ -309,7 +309,7 @@ class AgentNode(BaseNode):
}
)
value = tool_value
if parameter.type == "model-selector":
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value)
# memory config
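In the agent node, the raw string "model-selector" is replaced by the corresponding parameter-type enum member. Assuming the member's value is still that same string (which the comparison relies on), behavior is unchanged, but a typo now fails loudly instead of silently never matching. A hypothetical illustration with a stand-in enum, not the project's actual class:

    from enum import Enum

    class ParamType(str, Enum):  # stand-in for the AgentStrategyParameterType enum
        MODEL_SELECTOR = "model-selector"

    def is_model_selector(parameter_type) -> bool:
        # A str-valued enum member compares equal to its raw value,
        # so existing "model-selector" strings keep matching.
        return parameter_type == ParamType.MODEL_SELECTOR

    assert is_model_selector("model-selector")
    assert is_model_selector(ParamType.MODEL_SELECTOR)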

@ -228,7 +228,7 @@ class KnowledgeRetrievalNode(BaseNode):
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.filter(
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@ -242,8 +242,8 @@ class KnowledgeRetrievalNode(BaseNode):
results = (
db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
@ -370,7 +370,7 @@ class KnowledgeRetrievalNode(BaseNode):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.filter(
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
@ -415,7 +415,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(Document).filter(
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
@ -493,9 +493,9 @@ class KnowledgeRetrievalNode(BaseNode):
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore
document_query = document_query.filter(and_(*filters))
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(BaseNode):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
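The dataset-availability check above keeps its shape under the rename: a grouped subquery counts completed, enabled, unarchived documents per dataset, and Dataset is outer-joined to it so that external datasets, which carry no documents of their own, still pass via the OR branch of the final condition. A condensed sketch with the grouping step made explicit (the hunk cuts it off) and imports assumed:

    from sqlalchemy import func

    from extensions.ext_database import db
    from models.dataset import Dataset, Document

    def available_datasets(tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
        subquery = (
            db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
            .where(
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
                Document.dataset_id.in_(dataset_ids),
            )
            .group_by(Document.dataset_id)
            .subquery()
        )
        return (
            db.session.query(Dataset)
            .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
            .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
            .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
            .all()
        )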

@ -5,6 +5,11 @@ set -e
if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
echo "Running migrations"
flask upgrade-db
# Pure migration mode
if [[ "${MODE}" == "migration" ]]; then
echo "Migration completed, exiting normally"
exit 0
fi
fi
if [[ "${MODE}" == "worker" ]]; then

@ -22,7 +22,7 @@ def handle(sender, **kwargs):
document = (
db.session.query(Document)
.filter(
.where(
Document.id == document_id,
Document.dataset_id == dataset_id,
)

@ -13,7 +13,7 @@ def handle(sender, **kwargs):
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:
@ -27,7 +27,7 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()

@ -15,7 +15,7 @@ def handle(sender, **kwargs):
published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:
@ -29,7 +29,7 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()

@ -40,9 +40,9 @@ def load_user_from_request(request_from_flask_login):
if workspace_id:
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == workspace_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role == "owner")
.where(Tenant.id == workspace_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role == "owner")
.one_or_none()
)
if tenant_account_join:
@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
@ -78,12 +78,12 @@ def load_user_from_request(request_from_flask_login):
server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code:
raise Unauthorized("Invalid Authorization token.")
app_mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not app_mcp_server:
raise NotFound("App MCP server not found.")
end_user = (
db.session.query(EndUser)
.filter(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp")
.where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp")
.first()
)
if not end_user:

@ -261,13 +261,11 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
tool_file = (
db.session.query(ToolFile)
.filter(
tool_file = db.session.scalar(
select(ToolFile).where(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id,
)
.first()
)
if tool_file is None:
@ -275,7 +273,7 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")

@ -25,6 +25,7 @@ class EmailType(Enum):
EMAIL_CODE_LOGIN = "email_code_login"
CHANGE_EMAIL_OLD = "change_email_old"
CHANGE_EMAIL_NEW = "change_email_new"
CHANGE_EMAIL_COMPLETED = "change_email_completed"
OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
@ -344,6 +345,18 @@ def create_default_email_config() -> EmailI18nConfig:
branded_template_path="without-brand/change_mail_confirm_new_template_zh-CN.html",
),
},
EmailType.CHANGE_EMAIL_COMPLETED: {
EmailLanguage.EN_US: EmailTemplate(
subject="Your login email has been changed",
template_path="change_mail_completed_template_en-US.html",
branded_template_path="without-brand/change_mail_completed_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的登录邮箱已更改",
template_path="change_mail_completed_template_zh-CN.html",
branded_template_path="without-brand/change_mail_completed_template_zh-CN.html",
),
},
EmailType.OWNER_TRANSFER_CONFIRM: {
EmailLanguage.EN_US: EmailTemplate(
subject="Verify Your Request to Transfer Workspace Ownership",
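This is the change the branch is named for: a new CHANGE_EMAIL_COMPLETED email type with English and Simplified Chinese templates, each in branded and unbranded variants, registered in the default i18n config next to the existing change-email messages. A usage sketch; the send helper below is illustrative and not taken from the diff, and the import path is assumed:

    from libs.email_i18n import EmailLanguage, EmailType  # import path assumed

    def notify_email_change_completed(send_email, to: str, use_chinese: bool) -> None:
        # The diff only adds the enum member and its template entries;
        # send_email stands in for whatever mailer the service layer wires up.
        language = EmailLanguage.ZH_HANS if use_chinese else EmailLanguage.EN_US
        send_email(email_type=EmailType.CHANGE_EMAIL_COMPLETED, language=language, to=to)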

@ -3,6 +3,7 @@ from typing import Any
import requests
from flask_login import current_user
from sqlalchemy import select
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@ -61,16 +62,12 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
.first()
)
if data_source_binding:
data_source_binding.source_info = source_info
@ -101,16 +98,12 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
.first()
)
if data_source_binding:
data_source_binding.source_info = source_info
@ -129,18 +122,15 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str):
# save data source binding
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
.first()
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
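The Notion OAuth lookups also drop the db.and_(...) wrapper: criteria passed to select().where() together are combined with AND implicitly, and db.session.scalar() stands in for .first() on the legacy query. A minimal sketch of the rewritten lookup, imports assumed:

    from sqlalchemy import select

    from extensions.ext_database import db
    from models.source import DataSourceOauthBinding

    def find_binding(tenant_id: str, access_token: str):
        # Multiple .where() criteria are ANDed, so and_() is no longer needed.
        return db.session.scalar(
            select(DataSourceOauthBinding).where(
                DataSourceOauthBinding.tenant_id == tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.access_token == access_token,
            )
        )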

@ -1,4 +1,5 @@
import hashlib
import os
from typing import Union
from Crypto.Cipher import AES
@ -17,7 +18,7 @@ def generate_key_pair(tenant_id: str) -> str:
pem_private = private_key.export_key()
pem_public = public_key.export_key()
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
filepath = os.path.join("privkeys", tenant_id, "private.pem")
storage.save(filepath, pem_private)
@ -47,7 +48,7 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
filepath = os.path.join("privkeys", tenant_id, "private.pem")
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
private_key = redis_client.get(cache_key)
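The private-key path is now assembled with os.path.join instead of str.format. On POSIX the resulting storage key is identical ("privkeys/<tenant_id>/private.pem"); on platforms with a different separator the join form follows os.sep, which only matters if the storage backend treats keys as literal strings. A one-line sketch:

    import os

    def private_key_path(tenant_id: str) -> str:
        # "privkeys/{tenant_id}/private.pem" built from path components.
        return os.path.join("privkeys", tenant_id, "private.pem")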

@ -4,7 +4,7 @@ from datetime import datetime
from typing import Optional, cast
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base
@ -119,7 +119,7 @@ class Account(UserMixin, Base):
@current_tenant.setter
def current_tenant(self, tenant: "Tenant"):
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
if ta:
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant
@ -135,9 +135,9 @@ class Account(UserMixin, Base):
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id)
.where(Tenant.id == tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == self.id)
.one_or_none()
),
)
@ -161,11 +161,11 @@ class Account(UserMixin, Base):
def get_by_openid(cls, provider: str, open_id: str):
account_integrate = (
db.session.query(AccountIntegrate)
.filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none()
)
if account_integrate:
return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none()
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return None
# check current_user.current_tenant.current_role in ['admin', 'owner']
@ -211,7 +211,7 @@ class Tenant(Base):
def get_accounts(self) -> list[Account]:
return (
db.session.query(Account)
.filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.all()
)

@ -12,7 +12,7 @@ from datetime import datetime
from json import JSONDecodeError
from typing import Any, Optional, cast
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
@ -68,7 +68,7 @@ class Dataset(Base):
@property
def dataset_keyword_table(self):
dataset_keyword_table = (
db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
@ -95,7 +95,7 @@ class Dataset(Base):
def latest_process_rule(self):
return (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id)
.where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
)
@ -104,19 +104,19 @@ class Dataset(Base):
def app_count(self):
return (
db.session.query(func.count(AppDatasetJoin.id))
.filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
)
@property
def document_count(self):
return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
@property
def available_document_count(self):
return (
db.session.query(func.count(Document.id))
.filter(
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
@ -129,7 +129,7 @@ class Dataset(Base):
def available_segment_count(self):
return (
db.session.query(func.count(DocumentSegment.id))
.filter(
.where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
@ -142,13 +142,13 @@ class Dataset(Base):
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
.filter(Document.dataset_id == self.id)
.where(Document.dataset_id == self.id)
.scalar()
)
@property
def doc_form(self):
document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
if document:
return document.doc_form
return None
@ -169,7 +169,7 @@ class Dataset(Base):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.filter(
.where(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
@ -185,14 +185,14 @@ class Dataset(Base):
if self.provider != "external":
return None
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
)
if not external_knowledge_binding:
return None
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis).where(
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id
)
)
if not external_knowledge_api:
return None
@ -205,7 +205,7 @@ class Dataset(Base):
@property
def doc_metadata(self):
dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()
dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
doc_metadata = [
{
@ -408,7 +408,7 @@ class Document(Base):
data_source_info_dict = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.filter(UploadFile.id == data_source_info_dict["upload_file_id"])
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
)
if file_detail:
@ -441,24 +441,24 @@ class Document(Base):
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
@property
def segment_count(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
@property
def hit_count(self):
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.filter(DocumentSegment.document_id == self.id)
.where(DocumentSegment.document_id == self.id)
.scalar()
)
@property
def uploader(self):
user = db.session.query(Account).filter(Account.id == self.created_by).first()
user = db.session.query(Account).where(Account.id == self.created_by).first()
return user.name if user else None
@property
@ -475,7 +475,7 @@ class Document(Base):
document_metadatas = (
db.session.query(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.filter(
.where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
@ -687,26 +687,26 @@ class DocumentSegment(Base):
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()
return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property
def previous_segment(self):
return (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
.first()
return db.session.scalar(
select(DocumentSegment).where(
DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1
)
)
@property
def next_segment(self):
return (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
.first()
return db.session.scalar(
select(DocumentSegment).where(
DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1
)
)
@property
@ -717,7 +717,7 @@ class DocumentSegment(Base):
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
@ -734,7 +734,7 @@ class DocumentSegment(Base):
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
@ -825,15 +825,15 @@ class ChildChunk(Base):
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()
return db.session.query(Document).where(Document.id == self.document_id).first()
@property
def segment(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(Base):
@ -1044,11 +1044,11 @@ class ExternalKnowledgeApis(Base):
def dataset_bindings(self):
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
.filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all()
)
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
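The dataset_bindings property above keeps the legacy Query(...).all() form for its IN (...) lookup, only renaming filter to where. The 2.0-style equivalent goes through session.scalars(); roughly, with illustrative names:

from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session

def datasets_by_ids(session: Session, model, ids: Sequence[str]):
    # scalars() unwraps each row to the mapped object; .all() materialises the list.
    return session.scalars(select(model).where(model.id.in_(ids))).all()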

@ -113,13 +113,13 @@ class App(Base):
@property
def site(self):
site = db.session.query(Site).filter(Site.app_id == self.id).first()
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
def app_model_config(self):
if self.app_model_config_id:
return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return None
@ -128,7 +128,7 @@ class App(Base):
if self.workflow_id:
from .workflow import Workflow
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return None
@ -138,7 +138,7 @@ class App(Base):
@property
def tenant(self):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@property
@ -282,7 +282,7 @@ class App(Base):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.filter(
.where(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
@ -296,7 +296,7 @@ class App(Base):
@property
def author_name(self):
if self.created_by:
account = db.session.query(Account).filter(Account.id == self.created_by).first()
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
@ -338,7 +338,7 @@ class AppModelConfig(Base):
@property
def app(self):
app = db.session.query(App).filter(App.id == self.app_id).first()
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
@ -372,7 +372,7 @@ class AppModelConfig(Base):
@property
def annotation_reply_dict(self) -> dict:
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@ -577,7 +577,7 @@ class RecommendedApp(Base):
@property
def app(self):
app = db.session.query(App).filter(App.id == self.app_id).first()
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@ -601,12 +601,12 @@ class InstalledApp(Base):
@property
def app(self):
app = db.session.query(App).filter(App.id == self.app_id).first()
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def tenant(self):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@ -714,7 +714,7 @@ class Conversation(Base):
model_config["configs"] = override_model_configs
else:
app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
)
if app_model_config:
model_config = app_model_config.to_dict()
@ -737,21 +737,21 @@ class Conversation(Base):
@property
def annotated(self):
return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).count() > 0
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
@property
def annotation(self):
return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).first()
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
@property
def message_count(self):
return db.session.query(Message).filter(Message.conversation_id == self.id).count()
return db.session.query(Message).where(Message.conversation_id == self.id).count()
@property
def user_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.filter(
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
@ -761,7 +761,7 @@ class Conversation(Base):
dislike = (
db.session.query(MessageFeedback)
.filter(
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
@ -775,7 +775,7 @@ class Conversation(Base):
def admin_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.filter(
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
@ -785,7 +785,7 @@ class Conversation(Base):
dislike = (
db.session.query(MessageFeedback)
.filter(
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
@ -797,7 +797,7 @@ class Conversation(Base):
@property
def status_count(self):
messages = db.session.query(Message).filter(Message.conversation_id == self.id).all()
messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@ -824,19 +824,19 @@ class Conversation(Base):
def first_message(self):
return (
db.session.query(Message)
.filter(Message.conversation_id == self.id)
.where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
)
@property
def app(self):
return db.session.query(App).filter(App.id == self.app_id).first()
return db.session.query(App).where(App.id == self.app_id).first()
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == self.from_end_user_id).first()
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
if end_user:
return end_user.session_id
@ -845,7 +845,7 @@ class Conversation(Base):
@property
def from_account_name(self):
if self.from_account_id:
account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account:
return account.name
@ -1040,7 +1040,7 @@ class Message(Base):
def user_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
)
return feedback
@ -1049,30 +1049,30 @@ class Message(Base):
def admin_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
)
return feedback
@property
def feedbacks(self):
feedbacks = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id).all()
feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
return feedbacks
@property
def annotation(self):
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first()
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
return annotation
@property
def annotation_hit_history(self):
annotation_history = (
db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first()
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
)
if annotation_history:
annotation = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.id == annotation_history.annotation_id)
.where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
)
return annotation
@ -1080,11 +1080,9 @@ class Message(Base):
@property
def app_model_config(self):
conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
if conversation:
return (
db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first()
)
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return None
@ -1100,7 +1098,7 @@ class Message(Base):
def agent_thoughts(self):
return (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == self.id)
.where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
@ -1113,8 +1111,8 @@ class Message(Base):
def message_files(self):
from factories import file_factory
message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
current_app = db.session.query(App).filter(App.id == self.app_id).first()
message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@ -1178,7 +1176,7 @@ class Message(Base):
if self.workflow_run_id:
from .workflow import WorkflowRun
return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
return db.session.query(WorkflowRun).where(WorkflowRun.id == self.workflow_run_id).first()
return None
@ -1253,7 +1251,7 @@ class MessageFeedback(Base):
@property
def from_account(self):
account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self):
@ -1335,12 +1333,12 @@ class MessageAnnotation(Base):
@property
def account(self):
account = db.session.query(Account).filter(Account.id == self.account_id).first()
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).filter(Account.id == self.account_id).first()
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
@ -1371,14 +1369,14 @@ class AppAnnotationHitHistory(Base):
account = (
db.session.query(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.filter(MessageAnnotation.id == self.annotation_id)
.where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).filter(Account.id == self.account_id).first()
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
@ -1404,7 +1402,7 @@ class AppAnnotationSetting(Base):
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == self.collection_binding_id)
.where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
)
return collection_binding_detail
@ -1470,7 +1468,7 @@ class AppMCPServer(Base):
def generate_server_code(n):
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0:
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
result = generate_string(n)
return result
@ -1527,7 +1525,7 @@ class Site(Base):
def generate_code(n):
while True:
result = generate_string(n)
while db.session.query(Site).filter(Site.code == result).count() > 0:
while db.session.query(Site).where(Site.code == result).count() > 0:
result = generate_string(n)
return result
@ -1558,7 +1556,7 @@ class ApiToken(Base):
def generate_api_key(prefix, n):
while True:
result = prefix + generate_string(n)
if db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0:
if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
continue
return result
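The generate_server_code / generate_code / generate_api_key helpers above keep their count() > 0 existence checks, only renaming filter to where. An EXISTS query asks the same question without counting every match; a hedged sketch with placeholder column names:

import secrets
from sqlalchemy import exists, select
from sqlalchemy.orm import Session

def generate_unique_code(session: Session, column, length: int = 16) -> str:
    while True:
        candidate = secrets.token_hex(length // 2)
        # Same check as: session.query(model).where(column == candidate).count() > 0
        taken = session.scalar(select(exists().where(column == candidate)))
        if not taken:
            return candidate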

@ -153,11 +153,11 @@ class ApiToolProvider(Base):
def user(self) -> Account | None:
if not self.user_id:
return None
return db.session.query(Account).filter(Account.id == self.user_id).first()
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(Base):
@ -223,11 +223,11 @@ class WorkflowToolProvider(Base):
@property
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@ -235,7 +235,7 @@ class WorkflowToolProvider(Base):
@property
def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first()
return db.session.query(App).where(App.id == self.app_id).first()
class MCPToolProvider(Base):
@ -280,11 +280,11 @@ class MCPToolProvider(Base):
)
def load_user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def credentials(self) -> dict:

@ -26,7 +26,7 @@ class SavedMessage(Base):
@property
def message(self):
return db.session.query(Message).filter(Message.id == self.message_id).first()
return db.session.query(Message).where(Message.id == self.message_id).first()
class PinnedConversation(Base):

@ -343,7 +343,7 @@ class Workflow(Base):
return (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
.where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
.count()
> 0
)
@ -549,12 +549,12 @@ class WorkflowRun(Base):
from models.model import Message
return (
db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
)
@property
def workflow(self):
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
def to_dict(self):
return {

@ -21,7 +21,7 @@ def clean_embedding_cache_task():
try:
embedding_ids = (
db.session.query(Embedding.id)
.filter(Embedding.created_at < thirty_days_ago)
.where(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc())
.limit(100)
.all()
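clean_embedding_cache_task above selects a single column, Embedding.id, so the legacy query returns one-element Row tuples that later need unpacking. In 2.0 style the same lookup is usually written with session.scalars(), which yields the bare ids; a sketch with placeholder names:

from sqlalchemy import select
from sqlalchemy.orm import Session

def stale_embedding_ids(session: Session, Embedding, cutoff, batch: int = 100):
    stmt = (
        select(Embedding.id)
        .where(Embedding.created_at < cutoff)
        .order_by(Embedding.created_at.desc())
        .limit(batch)
    )
    # scalars() yields plain id values instead of one-element Row tuples.
    return session.scalars(stmt).all()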

@ -36,7 +36,7 @@ def clean_messages():
# Main query with join and filter
messages = (
db.session.query(Message)
.filter(Message.created_at < plan_sandbox_clean_message_day)
.where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
@ -66,25 +66,25 @@ def clean_messages():
plan = plan_cache.decode()
if plan == "sandbox":
# clean related message
db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message.id).delete(
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message.id).delete(
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageChain).filter(MessageChain.message_id == message.id).delete(
db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).delete(
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageFile).filter(MessageFile.message_id == message.id).delete(
db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(SavedMessage).filter(SavedMessage.message_id == message.id).delete(
db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(Message).filter(Message.id == message.id).delete()
db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green"))

@ -27,7 +27,7 @@ def clean_unused_datasets_task():
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@ -40,7 +40,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@ -55,7 +55,7 @@ def clean_unused_datasets_task():
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
.where(
Dataset.created_at < plan_sandbox_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
@ -72,7 +72,7 @@ def clean_unused_datasets_task():
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.filter(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
.where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
@ -80,7 +80,7 @@ def clean_unused_datasets_task():
# add auto disable log
documents = (
db.session.query(Document)
.filter(
.where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
@ -111,7 +111,7 @@ def clean_unused_datasets_task():
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@ -124,7 +124,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.filter(
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@ -139,7 +139,7 @@ def clean_unused_datasets_task():
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
.where(
Dataset.created_at < plan_pro_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
@ -155,7 +155,7 @@ def clean_unused_datasets_task():
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.filter(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
.where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
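Both plan branches of clean_unused_datasets_task build the same query shape: a labelled count-per-dataset subquery, outer-joined to Dataset, with func.coalesce(..., 0) turning a missing join into a zero count. A reduced sketch of that shape, with the date column and cutoff as placeholders:

from sqlalchemy import func, select

def stale_datasets_stmt(Dataset, Document, cutoff):
    # Completed documents per dataset, updated after the cutoff.
    recent_docs = (
        select(Document.dataset_id, func.count(Document.id).label("document_count"))
        .where(Document.indexing_status == "completed", Document.updated_at > cutoff)
        .group_by(Document.dataset_id)
        .subquery()
    )
    # Datasets with no recent documents: the outer join yields NULL counts,
    # which coalesce() maps to 0.
    return (
        select(Dataset)
        .outerjoin(recent_docs, Dataset.id == recent_docs.c.dataset_id)
        .where(func.coalesce(recent_docs.c.document_count, 0) == 0)
    )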

@ -20,7 +20,7 @@ def create_tidb_serverless_task():
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
)
if idle_tidb_serverless_number >= tidb_serverless_number:
break

@ -30,7 +30,7 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all()
)
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
@ -45,7 +45,7 @@ def mail_clean_document_notify_task():
if plan != "sandbox":
knowledge_details = []
# check tenant
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
continue
# check current owner
@ -54,7 +54,7 @@ def mail_clean_document_notify_task():
)
if not current_owner_join:
continue
account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first()
account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
if not account:
continue
@ -67,7 +67,7 @@ def mail_clean_document_notify_task():
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

@ -17,7 +17,7 @@ def update_tidb_serverless_status_task():
# check the number of idle tidb serverless
tidb_serverless_list = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
if len(tidb_serverless_list) == 0:
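A note on comparisons like TidbAuthBinding.active == False in the hunk above: the == builds a SQL boolean expression, so it cannot be rewritten as not TidbAuthBinding.active; codebases that want to satisfy linters typically spell it .is_(False) instead, which renders as IS false. A small sketch with a placeholder model:

from sqlalchemy import select

def creating_but_inactive(session, Binding):
    # Binding.active == False and Binding.active.is_(False) produce
    # equivalent filters for a non-NULL boolean column.
    stmt = select(Binding).where(Binding.active.is_(False), Binding.status == "CREATING")
    return session.scalars(stmt).all()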

Some files were not shown because too many files have changed in this diff.