From 7c04b4a38efc20d5e87bbb7a33ad13333863b9f9 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 23 Jul 2025 04:11:51 +0900 Subject: [PATCH] more .first() --- api/core/ops/utils.py | 4 ++-- api/factories/file_factory.py | 8 ++++---- api/libs/oauth_data_source.py | 32 ++++++++++++++++---------------- api/models/account.py | 6 ++++-- api/models/dataset.py | 30 +++++++++++++++--------------- 5 files changed, 41 insertions(+), 39 deletions(-) diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 36d060afd2..39308e883b 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from datetime import datetime from typing import Optional, Union from urllib.parse import urlparse - +from sqlalchemy import select from extensions.ext_database import db from models.model import Message @@ -20,7 +20,7 @@ def filter_none_values(data: dict): def get_message_data(message_id: str): - return db.session.query(Message).filter(Message.id == message_id).first() + return db.session.scalars(select(Message).filter(Message.id == message_id).limit(1)).first() @contextmanager diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index c974dbb700..861b067ed2 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -261,14 +261,14 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - tool_file = ( - db.session.query(ToolFile) + tool_file = db.session.scalars( + select(ToolFile) .filter( ToolFile.id == mapping.get("tool_file_id"), ToolFile.tenant_id == tenant_id, ) - .first() - ) + .limit(1) + ).first() if tool_file is None: raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 78f827584c..0be9669acb 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -3,7 +3,7 @@ from typing import Any import requests from flask_login import current_user - +from sqlalchemy import select, and_ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -61,17 +61,17 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) + data_source_binding = db.session.scalars( + select(DataSourceOauthBinding) .filter( - db.and_( + and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.access_token == access_token, ) ) - .first() - ) + .limit(1) + ).first() if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False @@ -101,17 +101,17 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) + data_source_binding = db.session.scalars( + select(DataSourceOauthBinding) .filter( - db.and_( + and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.access_token == access_token, ) ) - .first() - ) + .limit(1) + ).first() if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False @@ -129,18 +129,18 @@ class NotionOAuth(OAuthDataSource): def sync_data_source(self, binding_id: str): # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) + data_source_binding = db.session.scalars( + select(DataSourceOauthBinding) .filter( - db.and_( + and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.disabled == False, ) ) - .first() - ) + .limit(1) + ).first() if data_source_binding: # get all authorized pages pages = self.get_authorized_pages(data_source_binding.access_token) diff --git a/api/models/account.py b/api/models/account.py index 01d1625dbd..2f5df073cb 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Optional, cast from flask_login import UserMixin # type: ignore -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -119,7 +119,9 @@ class Account(UserMixin, Base): @current_tenant.setter def current_tenant(self, tenant: "Tenant"): - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() + ta = db.session.scalars( + select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1) + ).first() if ta: self.role = TenantAccountRole(ta.role) self._current_tenant = tenant diff --git a/api/models/dataset.py b/api/models/dataset.py index d5a13efb90..1c1fc1c95b 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,7 +12,7 @@ from datetime import datetime from json import JSONDecodeError from typing import Any, Optional, cast -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -189,11 +189,11 @@ class Dataset(Base): ) if not external_knowledge_binding: return None - external_knowledge_api = ( - db.session.query(ExternalKnowledgeApis) + external_knowledge_api = db.session.scalars( + select(ExternalKnowledgeApis) .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) - .first() - ) + .limit(1) + ).first() if not external_knowledge_api: return None return { @@ -687,27 +687,27 @@ class DocumentSegment(Base): @property def dataset(self): - return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + return db.session.scalars(select(Dataset).filter(Dataset.id == self.dataset_id).limit(1)).first() @property def document(self): - return db.session.query(Document).filter(Document.id == self.document_id).first() + return db.session.scalars(select(Document).filter(Document.id == self.document_id).limit(1)).first() @property def previous_segment(self): - return ( - db.session.query(DocumentSegment) + return db.session.scalars( + select(DocumentSegment) .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) - .first() - ) + .limit(1) + ).first() @property def next_segment(self): - return ( - db.session.query(DocumentSegment) + return db.session.scalars( + select(DocumentSegment) .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) - .first() - ) + .limit(1) + ).first() @property def child_chunks(self):