From 111e0f4cf293770186156b105a35d568c9c09e05 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 23 Jul 2025 13:40:34 +0900 Subject: [PATCH] simpler --- api/controllers/web/passport.py | 30 ++++++++++++++---------------- api/controllers/web/wraps.py | 6 +++--- api/core/ops/utils.py | 2 +- api/factories/file_factory.py | 8 +++----- api/libs/oauth_data_source.py | 24 +++++++++--------------- 5 files changed, 30 insertions(+), 40 deletions(-) diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index e45c86f104..532863c3ba 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -43,18 +43,18 @@ class PassportResource(Resource): raise WebAppAuthRequiredError() # get site from db and check if it is normal - site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first() + site = db.session.scalar(select(Site).filter(Site.code == app_code, Site.status == "normal")) if not site: raise NotFound() # get app from db and check if it is normal and enable_site - app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first() + app_model = db.session.scalar(select(App).filter(App.id == site.app_id)) if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() if user_id: - end_user = db.session.scalars( - select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1) - ).first() + end_user = db.session.scalar( + select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id) + ) if end_user: pass @@ -122,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not user_auth_type: raise Unauthorized("Missing auth_type in the token.") - site = db.session.scalars(select(Site).filter(Site.code == app_code, Site.status == "normal").limit(1)).first() + site = db.session.scalar(select(Site).filter(Site.code == app_code, Site.status == "normal")) if not site: raise NotFound() - app_model = db.session.scalars(select(App).filter(App.id == site.app_id).limit(1)).first() + app_model = db.session.scalar(select(App).filter(App.id == site.app_id)) if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() @@ -141,17 +141,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: end_user = None if end_user_id: - end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first() + end_user = db.session.scalar(select(EndUser).filter(EndUser.id == end_user_id)) if session_id: - end_user = db.session.scalars( - select(EndUser) - .filter( + end_user = db.session.scalar( + select(EndUser).filter( EndUser.session_id == session_id, EndUser.tenant_id == app_model.tenant_id, EndUser.app_id == app_model.id, ) - .limit(1) - ).first() + ) if not end_user: if not session_id: raise NotFound("Missing session_id for existing web user.") @@ -188,9 +186,9 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): user_id = token_decoded.get("user_id") end_user = None if user_id: - end_user = db.session.scalars( - select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).limit(1) - ).first() + end_user = db.session.scalar( + select(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id) + ) if not end_user: end_user = EndUser( diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 83d51b2f59..7c0c27def8 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -49,8 +49,8 @@ def decode_jwt_token(): decoded = PassportService().verify(tk) app_code = decoded.get("app_code") app_id = decoded.get("app_id") - app_model = db.session.scalars(select(App).filter(App.id == app_id).limit(1)).first() - site = db.session.scalars(select(Site).filter(Site.code == app_code).limit(1)).first() + app_model = db.session.scalar(select(App).filter(App.id == app_id)) + site = db.session.scalar(select(Site).filter(Site.code == app_code)) if not app_model: raise NotFound() if not app_code or not site: @@ -58,7 +58,7 @@ def decode_jwt_token(): if app_model.enable_site is False: raise BadRequest("Site is disabled.") end_user_id = decoded.get("end_user_id") - end_user = db.session.scalars(select(EndUser).filter(EndUser.id == end_user_id).limit(1)).first() + end_user = db.session.scalar(select(EndUser).filter(EndUser.id == end_user_id)) if not end_user: raise NotFound() diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 1743335e40..aaf4e3a637 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -22,7 +22,7 @@ def filter_none_values(data: dict): def get_message_data(message_id: str): - return db.session.scalars(select(Message).filter(Message.id == message_id).limit(1)).first() + return db.session.scalar(select(Message).filter(Message.id == message_id)) @contextmanager diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 861b067ed2..d14497186a 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -261,14 +261,12 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - tool_file = db.session.scalars( - select(ToolFile) - .filter( + tool_file = db.session.scalar( + select(ToolFile).filter( ToolFile.id == mapping.get("tool_file_id"), ToolFile.tenant_id == tenant_id, ) - .limit(1) - ).first() + ) if tool_file is None: raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 88b43da118..0ceeff8bc0 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -62,17 +62,15 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = db.session.scalars( - select(DataSourceOauthBinding) - .filter( + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).filter( and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.access_token == access_token, ) ) - .limit(1) - ).first() + ) if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False @@ -102,17 +100,15 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = db.session.scalars( - select(DataSourceOauthBinding) - .filter( + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).filter( and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.access_token == access_token, ) ) - .limit(1) - ).first() + ) if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False @@ -130,9 +126,8 @@ class NotionOAuth(OAuthDataSource): def sync_data_source(self, binding_id: str): # save data source binding - data_source_binding = db.session.scalars( - select(DataSourceOauthBinding) - .filter( + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).filter( and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", @@ -140,8 +135,7 @@ class NotionOAuth(OAuthDataSource): DataSourceOauthBinding.disabled == False, ) ) - .limit(1) - ).first() + ) if data_source_binding: # get all authorized pages pages = self.get_authorized_pages(data_source_binding.access_token)