pull/22801/head
Asuka Minato 10 months ago committed by -LAN-
parent 9bd823fda1
commit 111e0f4cf2
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -43,18 +43,18 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError() raise WebAppAuthRequiredError()
# get site from db and check if it is normal # get site from db and check if it is normal
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first() site = db.session.scalar(select(Site).filter(Site.code == app_code, Site.status == "normal"))
if not site: if not site:
raise NotFound() raise NotFound()
# get app from db and check if it is normal and enable_site # get app from db and check if it is normal and enable_site
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first() app_model = db.session.scalar(select(App).filter(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
if user_id: if user_id:
end_user = db.session.scalars( end_user = db.session.scalar(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1) select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
).first() )
if end_user: if end_user:
pass pass
@ -122,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type: if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.") raise Unauthorized("Missing auth_type in the token.")
site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first() site = db.session.scalar(select(Site).filter(Site.code == app_code, Site.status == "normal"))
if not site: if not site:
raise NotFound() raise NotFound()
app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first() app_model = db.session.scalar(select(App).filter(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
@ -141,17 +141,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None end_user = None
if end_user_id: if end_user_id:
end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first() end_user = db.session.scalar(select(EndUser).filter(EndUser.id == end_user_id))
if session_id: if session_id:
end_user = db.session.scalars( end_user = db.session.scalar(
select(EndUser) select(EndUser).filter(
.filter(
EndUser.session_id == session_id, EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id, EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id, EndUser.app_id == app_model.id,
) )
.limit(1) )
).first()
if not end_user: if not end_user:
if not session_id: if not session_id:
raise NotFound("Missing session_id for existing web user.") raise NotFound("Missing session_id for existing web user.")
@ -188,9 +186,9 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id") user_id = token_decoded.get("user_id")
end_user = None end_user = None
if user_id: if user_id:
end_user = db.session.scalars( end_user = db.session.scalar(
select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1) select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
).first() )
if not end_user: if not end_user:
end_user = EndUser( end_user = EndUser(

@ -49,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") app_code = decoded.get("app_code")
app_id = decoded.get("app_id") app_id = decoded.get("app_id")
app_model = db.session.scalars(select(App).filter(App.id == app_id).limit(1)).first() app_model = db.session.scalar(select(App).filter(App.id == app_id))
site = db.session.scalars(select(Site).filter(Site.code == app_code).limit(1)).first() site = db.session.scalar(select(Site).filter(Site.code == app_code))
if not app_model: if not app_model:
raise NotFound() raise NotFound()
if not app_code or not site: if not app_code or not site:
@ -58,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False: if app_model.enable_site is False:
raise BadRequest("Site is disabled.") raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id") end_user_id = decoded.get("end_user_id")
end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first() end_user = db.session.scalar(select(EndUser).filter(EndUser.id == end_user_id))
if not end_user: if not end_user:
raise NotFound() raise NotFound()

@ -22,7 +22,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str): def get_message_data(message_id: str):
return db.session.scalars(select(Message).filter(Message.id == message_id).limit(1)).first() return db.session.scalar(select(Message).filter(Message.id == message_id))
@contextmanager @contextmanager

@ -261,14 +261,12 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
strict_type_validation: bool = False, strict_type_validation: bool = False,
) -> File: ) -> File:
tool_file = db.session.scalars( tool_file = db.session.scalar(
select(ToolFile) select(ToolFile).filter(
.filter(
ToolFile.id == mapping.get("tool_file_id"), ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id, ToolFile.tenant_id == tenant_id,
) )
.limit(1) )
).first()
if tool_file is None: if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")

@ -62,17 +62,15 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = db.session.scalars( data_source_binding = db.session.scalar(
select(DataSourceOauthBinding) select(DataSourceOauthBinding).filter(
.filter(
and_( and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
) )
.limit(1) )
).first()
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -102,17 +100,15 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = db.session.scalars( data_source_binding = db.session.scalar(
select(DataSourceOauthBinding) select(DataSourceOauthBinding).filter(
.filter(
and_( and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
) )
.limit(1) )
).first()
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -130,9 +126,8 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str): def sync_data_source(self, binding_id: str):
# save data source binding # save data source binding
data_source_binding = db.session.scalars( data_source_binding = db.session.scalar(
select(DataSourceOauthBinding) select(DataSourceOauthBinding).filter(
.filter(
and_( and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
@ -140,8 +135,7 @@ class NotionOAuth(OAuthDataSource):
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
) )
.limit(1) )
).first()
if data_source_binding: if data_source_binding:
# get all authorized pages # get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token) pages = self.get_authorized_pages(data_source_binding.access_token)

Loading…
Cancel
Save