|
|
|
|
@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
|
|
|
|
|
|
|
|
|
|
from flask import request
|
|
|
|
|
from flask_restful import Resource
|
|
|
|
|
from sqlalchemy import func, select
|
|
|
|
|
from werkzeug.exceptions import NotFound, Unauthorized
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
@ -42,18 +43,18 @@ class PassportResource(Resource):
|
|
|
|
|
raise WebAppAuthRequiredError()
|
|
|
|
|
|
|
|
|
|
# get site from db and check if it is normal
|
|
|
|
|
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
|
|
|
|
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
|
|
|
|
|
if not site:
|
|
|
|
|
raise NotFound()
|
|
|
|
|
# get app from db and check if it is normal and enable_site
|
|
|
|
|
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
|
|
|
|
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
|
|
|
|
|
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
|
|
|
|
raise NotFound()
|
|
|
|
|
|
|
|
|
|
if user_id:
|
|
|
|
|
end_user = (
|
|
|
|
|
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
|
|
|
|
|
)
|
|
|
|
|
end_user = db.session.scalars(
|
|
|
|
|
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
if end_user:
|
|
|
|
|
pass
|
|
|
|
|
@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
|
|
|
|
if not user_auth_type:
|
|
|
|
|
raise Unauthorized("Missing auth_type in the token.")
|
|
|
|
|
|
|
|
|
|
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
|
|
|
|
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
|
|
|
|
|
if not site:
|
|
|
|
|
raise NotFound()
|
|
|
|
|
|
|
|
|
|
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
|
|
|
|
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
|
|
|
|
|
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
|
|
|
|
raise NotFound()
|
|
|
|
|
|
|
|
|
|
@ -140,17 +141,17 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
|
|
|
|
|
|
|
|
|
end_user = None
|
|
|
|
|
if end_user_id:
|
|
|
|
|
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
|
|
|
|
|
end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first()
|
|
|
|
|
if session_id:
|
|
|
|
|
end_user = (
|
|
|
|
|
db.session.query(EndUser)
|
|
|
|
|
end_user = db.session.scalars(
|
|
|
|
|
select(EndUser)
|
|
|
|
|
.filter(
|
|
|
|
|
EndUser.session_id == session_id,
|
|
|
|
|
EndUser.tenant_id == app_model.tenant_id,
|
|
|
|
|
EndUser.app_id == app_model.id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
.limit(1)
|
|
|
|
|
).first()
|
|
|
|
|
if not end_user:
|
|
|
|
|
if not session_id:
|
|
|
|
|
raise NotFound("Missing session_id for existing web user.")
|
|
|
|
|
@ -187,9 +188,9 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
|
|
|
|
|
user_id = token_decoded.get("user_id")
|
|
|
|
|
end_user = None
|
|
|
|
|
if user_id:
|
|
|
|
|
end_user = (
|
|
|
|
|
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
|
|
|
|
|
)
|
|
|
|
|
end_user = db.session.scalars(
|
|
|
|
|
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
if not end_user:
|
|
|
|
|
end_user = EndUser(
|
|
|
|
|
@ -224,6 +225,8 @@ def generate_session_id():
|
|
|
|
|
"""
|
|
|
|
|
while True:
|
|
|
|
|
session_id = str(uuid.uuid4())
|
|
|
|
|
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
|
|
|
|
|
existing_count = db.session.scalar(
|
|
|
|
|
select(func.count()).select_from(EndUser).filter(EndUser.session_id == session_id)
|
|
|
|
|
)
|
|
|
|
|
if existing_count == 0:
|
|
|
|
|
return session_id
|
|
|
|
|
|