diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 10c3cdcf0e..e45c86f104 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta from flask import request from flask_restful import Resource +from sqlalchemy import func, select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -42,18 +43,18 @@ class PassportResource(Resource): raise WebAppAuthRequiredError() # get site from db and check if it is normal - site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first() if not site: raise NotFound() # get app from db and check if it is normal and enable_site - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first() if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() if user_id: - end_user = ( - db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() - ) + end_user = db.session.scalars( + select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1) + ).first() if end_user: pass @@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not user_auth_type: raise Unauthorized("Missing auth_type in the token.") - site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first() if not site: raise NotFound() - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first() if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() @@ -140,17 +141,17 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: end_user = None if end_user_id: - end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first() if session_id: - end_user = ( - db.session.query(EndUser) + end_user = db.session.scalars( + select(EndUser) .filter( EndUser.session_id == session_id, EndUser.tenant_id == app_model.tenant_id, EndUser.app_id == app_model.id, ) - .first() - ) + .limit(1) + ).first() if not end_user: if not session_id: raise NotFound("Missing session_id for existing web user.") @@ -187,9 +188,9 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): user_id = token_decoded.get("user_id") end_user = None if user_id: - end_user = ( - db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() - ) + end_user = db.session.scalars( + select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1) + ).first() if not end_user: end_user = EndUser( @@ -224,6 +225,8 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() + existing_count = db.session.scalar( + select(func.count()).select_from(EndUser).filter(EndUser.session_id == session_id) + ) if existing_count == 0: return session_id diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 154bddfc5c..83d51b2f59 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -3,6 +3,7 @@ from functools import wraps from flask import request from flask_restful import Resource +from sqlalchemy import select from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError @@ -48,8 +49,8 @@ def decode_jwt_token(): decoded = PassportService().verify(tk) app_code = decoded.get("app_code") app_id = decoded.get("app_id") - app_model = db.session.query(App).filter(App.id == app_id).first() - site = db.session.query(Site).filter(Site.code == app_code).first() + app_model = db.session.scalars(select(App).filter(App.id == app_id).limit(1)).first() + site = db.session.scalars(select(Site).filter(Site.code == app_code).limit(1)).first() if not app_model: raise NotFound() if not app_code or not site: @@ -57,7 +58,7 @@ def decode_jwt_token(): if app_model.enable_site is False: raise BadRequest("Site is disabled.") end_user_id = decoded.get("end_user_id") - end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first() if not end_user: raise NotFound()