migrate 2 files

pull/22801/head
Asuka Minato 10 months ago committed by -LAN-
parent 0731db8c22
commit 8861b25597
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config from configs import dify_config
@ -42,18 +43,18 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError() raise WebAppAuthRequiredError()
# get site from db and check if it is normal # get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
if not site: if not site:
raise NotFound() raise NotFound()
# get app from db and check if it is normal and enable_site # get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first() app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
if user_id: if user_id:
end_user = ( end_user = db.session.scalars(
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
) ).first()
if end_user: if end_user:
pass pass
@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type: if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.") raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
if not site: if not site:
raise NotFound() raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first() app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
@ -140,17 +141,17 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None end_user = None
if end_user_id: if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first()
if session_id: if session_id:
end_user = ( end_user = db.session.scalars(
db.session.query(EndUser) select(EndUser)
.filter( .filter(
EndUser.session_id == session_id, EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id, EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id, EndUser.app_id == app_model.id,
) )
.first() .limit(1)
) ).first()
if not end_user: if not end_user:
if not session_id: if not session_id:
raise NotFound("Missing session_id for existing web user.") raise NotFound("Missing session_id for existing web user.")
@ -187,9 +188,9 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id") user_id = token_decoded.get("user_id")
end_user = None end_user = None
if user_id: if user_id:
end_user = ( end_user = db.session.scalars(
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
) ).first()
if not end_user: if not end_user:
end_user = EndUser( end_user = EndUser(
@ -224,6 +225,8 @@ def generate_session_id():
""" """
while True: while True:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() existing_count = db.session.scalar(
select(func.count()).select_from(EndUser).filter(EndUser.session_id == session_id)
)
if existing_count == 0: if existing_count == 0:
return session_id return session_id

@ -3,6 +3,7 @@ from functools import wraps
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -48,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") app_code = decoded.get("app_code")
app_id = decoded.get("app_id") app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first() app_model = db.session.scalars(select(App).filter(App.id == app_id).limit(1)).first()
site = db.session.query(Site).filter(Site.code == app_code).first() site = db.session.scalars(select(Site).filter(Site.code == app_code).limit(1)).first()
if not app_model: if not app_model:
raise NotFound() raise NotFound()
if not app_code or not site: if not app_code or not site:
@ -57,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False: if app_model.enable_site is False:
raise BadRequest("Site is disabled.") raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id") end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first()
if not end_user: if not end_user:
raise NotFound() raise NotFound()

Loading…
Cancel
Save