more .first()

pull/22801/head
Asuka Minato 10 months ago committed by -LAN-
parent 8861b25597
commit 7c04b4a38e
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -2,7 +2,7 @@ from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Message from models.model import Message
@ -20,7 +20,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str): def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first() return db.session.scalars(select(Message).filter(Message.id == message_id).limit(1)).first()
@contextmanager @contextmanager

@ -261,14 +261,14 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
strict_type_validation: bool = False, strict_type_validation: bool = False,
) -> File: ) -> File:
tool_file = ( tool_file = db.session.scalars(
db.session.query(ToolFile) select(ToolFile)
.filter( .filter(
ToolFile.id == mapping.get("tool_file_id"), ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id, ToolFile.tenant_id == tenant_id,
) )
.first() .limit(1)
) ).first()
if tool_file is None: if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")

@ -3,7 +3,7 @@ from typing import Any
import requests import requests
from flask_login import current_user from flask_login import current_user
from sqlalchemy import select, and_
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding from models.source import DataSourceOauthBinding
@ -61,17 +61,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = ( data_source_binding = db.session.scalars(
db.session.query(DataSourceOauthBinding) select(DataSourceOauthBinding)
.filter( .filter(
db.and_( and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
) )
.first() .limit(1)
) ).first()
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -101,17 +101,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = ( data_source_binding = db.session.scalars(
db.session.query(DataSourceOauthBinding) select(DataSourceOauthBinding)
.filter( .filter(
db.and_( and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
) )
.first() .limit(1)
) ).first()
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -129,18 +129,18 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str): def sync_data_source(self, binding_id: str):
# save data source binding # save data source binding
data_source_binding = ( data_source_binding = db.session.scalars(
db.session.query(DataSourceOauthBinding) select(DataSourceOauthBinding)
.filter( .filter(
db.and_( and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
) )
.first() .limit(1)
) ).first()
if data_source_binding: if data_source_binding:
# get all authorized pages # get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token) pages = self.get_authorized_pages(data_source_binding.access_token)

@ -4,7 +4,7 @@ from datetime import datetime
from typing import Optional, cast from typing import Optional, cast
from flask_login import UserMixin # type: ignore from flask_login import UserMixin # type: ignore
from sqlalchemy import func from sqlalchemy import func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base from models.base import Base
@ -119,7 +119,9 @@ class Account(UserMixin, Base):
@current_tenant.setter @current_tenant.setter
def current_tenant(self, tenant: "Tenant"): def current_tenant(self, tenant: "Tenant"):
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() ta = db.session.scalars(
select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1)
).first()
if ta: if ta:
self.role = TenantAccountRole(ta.role) self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant self._current_tenant = tenant

@ -12,7 +12,7 @@ from datetime import datetime
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy import func from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -189,11 +189,11 @@ class Dataset(Base):
) )
if not external_knowledge_binding: if not external_knowledge_binding:
return None return None
external_knowledge_api = ( external_knowledge_api = db.session.scalars(
db.session.query(ExternalKnowledgeApis) select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.first() .limit(1)
) ).first()
if not external_knowledge_api: if not external_knowledge_api:
return None return None
return { return {
@ -687,27 +687,27 @@ class DocumentSegment(Base):
@property @property
def dataset(self): def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() return db.session.scalars(select(Dataset).filter(Dataset.id == self.dataset_id).limit(1)).first()
@property @property
def document(self): def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first() return db.session.scalars(select(Document).filter(Document.id == self.document_id).limit(1)).first()
@property @property
def previous_segment(self): def previous_segment(self):
return ( return db.session.scalars(
db.session.query(DocumentSegment) select(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
.first() .limit(1)
) ).first()
@property @property
def next_segment(self): def next_segment(self):
return ( return db.session.scalars(
db.session.query(DocumentSegment) select(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
.first() .limit(1)
) ).first()
@property @property
def child_chunks(self): def child_chunks(self):

Loading…
Cancel
Save