migrate 2 files

pull/22801/head
Asuka Minato 7 months ago committed by -LAN-
parent 0731db8c22
commit 8861b25597
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
from flask import request
from flask_restful import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@ -42,18 +43,18 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
)
end_user = db.session.scalars(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
).first()
if end_user:
pass
@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
if not site:
raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
@ -140,17 +141,17 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None
if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first()
if session_id:
end_user = (
db.session.query(EndUser)
end_user = db.session.scalars(
select(EndUser)
.filter(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.first()
)
.limit(1)
).first()
if not end_user:
if not session_id:
raise NotFound("Missing session_id for existing web user.")
@ -187,9 +188,9 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
)
end_user = db.session.scalars(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
).first()
if not end_user:
end_user = EndUser(
@ -224,6 +225,8 @@ def generate_session_id():
"""
while True:
session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
existing_count = db.session.scalar(
select(func.count()).select_from(EndUser).filter(EndUser.session_id == session_id)
)
if existing_count == 0:
return session_id

@ -3,6 +3,7 @@ from functools import wraps
from flask import request
from flask_restful import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -48,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.query(Site).filter(Site.code == app_code).first()
app_model = db.session.scalars(select(App).filter(App.id == app_id).limit(1)).first()
site = db.session.scalars(select(Site).filter(Site.code == app_code).limit(1)).first()
if not app_model:
raise NotFound()
if not app_code or not site:
@ -57,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first()
if not end_user:
raise NotFound()

Loading…
Cancel
Save