pull/22801/head
Asuka Minato 7 months ago committed by -LAN-
parent 9bd823fda1
commit 111e0f4cf2
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -43,18 +43,18 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
site = db.session.scalar(select(Site).filter(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
app_model = db.session.scalar(select(App).filter(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
if user_id:
end_user = db.session.scalars(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
).first()
end_user = db.session.scalar(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if end_user:
pass
@ -122,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first()
site = db.session.scalar(select(Site).filter(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first()
app_model = db.session.scalar(select(App).filter(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
@ -141,17 +141,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None
if end_user_id:
end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first()
end_user = db.session.scalar(select(EndUser).filter(EndUser.id == end_user_id))
if session_id:
end_user = db.session.scalars(
select(EndUser)
.filter(
end_user = db.session.scalar(
select(EndUser).filter(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.limit(1)
).first()
)
if not end_user:
if not session_id:
raise NotFound("Missing session_id for existing web user.")
@ -188,9 +186,9 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = db.session.scalars(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1)
).first()
end_user = db.session.scalar(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if not end_user:
end_user = EndUser(

@ -49,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.scalars(select(App).filter(App.id == app_id).limit(1)).first()
site = db.session.scalars(select(Site).filter(Site.code == app_code).limit(1)).first()
app_model = db.session.scalar(select(App).filter(App.id == app_id))
site = db.session.scalar(select(Site).filter(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
@ -58,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first()
end_user = db.session.scalar(select(EndUser).filter(EndUser.id == end_user_id))
if not end_user:
raise NotFound()

@ -22,7 +22,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str):
return db.session.scalars(select(Message).filter(Message.id == message_id).limit(1)).first()
return db.session.scalar(select(Message).filter(Message.id == message_id))
@contextmanager

@ -261,14 +261,12 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
tool_file = db.session.scalars(
select(ToolFile)
.filter(
tool_file = db.session.scalar(
select(ToolFile).filter(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id,
)
.limit(1)
).first()
)
if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")

@ -62,17 +62,15 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = db.session.scalars(
select(DataSourceOauthBinding)
.filter(
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).filter(
and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
.limit(1)
).first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@ -102,17 +100,15 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = db.session.scalars(
select(DataSourceOauthBinding)
.filter(
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).filter(
and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
.limit(1)
).first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@ -130,9 +126,8 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str):
# save data source binding
data_source_binding = db.session.scalars(
select(DataSourceOauthBinding)
.filter(
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).filter(
and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
@ -140,8 +135,7 @@ class NotionOAuth(OAuthDataSource):
DataSourceOauthBinding.disabled == False,
)
)
.limit(1)
).first()
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)

Loading…
Cancel
Save