more .first()

pull/22801/head
Asuka Minato 7 months ago committed by -LAN-
parent 8861b25597
commit 7c04b4a38e
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -2,7 +2,7 @@ from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db
from models.model import Message
@ -20,7 +20,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first()
return db.session.scalars(select(Message).filter(Message.id == message_id).limit(1)).first()
@contextmanager

@ -261,14 +261,14 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
tool_file = (
db.session.query(ToolFile)
tool_file = db.session.scalars(
select(ToolFile)
.filter(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id,
)
.first()
)
.limit(1)
).first()
if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")

@ -3,7 +3,7 @@ from typing import Any
import requests
from flask_login import current_user
from sqlalchemy import select, and_
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@ -61,17 +61,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = (
db.session.query(DataSourceOauthBinding)
data_source_binding = db.session.scalars(
select(DataSourceOauthBinding)
.filter(
db.and_(
and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
.first()
)
.limit(1)
).first()
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@ -101,17 +101,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = (
db.session.query(DataSourceOauthBinding)
data_source_binding = db.session.scalars(
select(DataSourceOauthBinding)
.filter(
db.and_(
and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
.first()
)
.limit(1)
).first()
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@ -129,18 +129,18 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str):
# save data source binding
data_source_binding = (
db.session.query(DataSourceOauthBinding)
data_source_binding = db.session.scalars(
select(DataSourceOauthBinding)
.filter(
db.and_(
and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
)
.first()
)
.limit(1)
).first()
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)

@ -4,7 +4,7 @@ from datetime import datetime
from typing import Optional, cast
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base
@ -119,7 +119,9 @@ class Account(UserMixin, Base):
@current_tenant.setter
def current_tenant(self, tenant: "Tenant"):
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
ta = db.session.scalars(
select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1)
).first()
if ta:
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant

@ -12,7 +12,7 @@ from datetime import datetime
from json import JSONDecodeError
from typing import Any, Optional, cast
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
@ -189,11 +189,11 @@ class Dataset(Base):
)
if not external_knowledge_binding:
return None
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
external_knowledge_api = db.session.scalars(
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.first()
)
.limit(1)
).first()
if not external_knowledge_api:
return None
return {
@ -687,27 +687,27 @@ class DocumentSegment(Base):
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
return db.session.scalars(select(Dataset).filter(Dataset.id == self.dataset_id).limit(1)).first()
@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()
return db.session.scalars(select(Document).filter(Document.id == self.document_id).limit(1)).first()
@property
def previous_segment(self):
return (
db.session.query(DocumentSegment)
return db.session.scalars(
select(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
.first()
)
.limit(1)
).first()
@property
def next_segment(self):
return (
db.session.query(DocumentSegment)
return db.session.scalars(
select(DocumentSegment)
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
.first()
)
.limit(1)
).first()
@property
def child_chunks(self):

Loading…
Cancel
Save