fix: 修复嵌套子查询异常

main
chenzhirong 3 weeks ago
parent 6d94c43685
commit f2da252a59

@ -3,144 +3,136 @@ import inspect
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, ParamSpec, TypeVar, Callable from typing import Any, ParamSpec, TypeVar, Callable
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_ from sqlalchemy import Executable, Result, and_, Select, Delete, Update
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.sql.selectable import Subquery
from common.constant import Constant from common.constant import Constant
from common.global_enums import IsDelete from common.global_enums import IsDelete
from config import settings from config import settings
from entity.base_entity import DbBaseModel
P = ParamSpec('P') P = ParamSpec('P')
T = TypeVar('T') T = TypeVar('T')
engine = create_async_engine( engine = create_async_engine(settings.database_url, echo=False, # 打印SQL日志生产环境建议关闭
settings.database_url, pool_size=10, # 连接池大小
echo=False, # 打印SQL日志生产环境建议关闭 max_overflow=20, # 最大溢出连接数
pool_size=10, # 连接池大小 pool_recycle=3600, # 连接回收时间解决MySQL超时断开问题【4†source】【5†source】
max_overflow=20, # 最大溢出连接数 )
pool_recycle=3600, # 连接回收时间解决MySQL超时断开问题【4†source】【5†source】
)
# 创建异步会话工厂 # 创建异步会话工厂
class EnhanceAsyncSession(AsyncSession): class EnhanceAsyncSession(AsyncSession):
async def scalar(self, statement: Executable,
params=None,
*,
execution_options=None,
bind_arguments=None,
**kw: Any, ):
def _add_logical_delete_condition(self, statement: Select) -> Select:
"""
Select 语句添加逻辑删除条件
支持递归处理 fastapi-pagination 生成的子查询包装
"""
# 特殊处理:如果当前查询只有一个子查询,直接修改其内部元素
# 这样可以避免产生重复的别名
if len(statement.froms) == 1 and isinstance(statement.froms[0], Subquery):
subquery = statement.froms[0]
# 递归处理子查询内部
processed_inner = self._add_logical_delete_condition(subquery.element)
# 如果内部有修改,直接替换 subquery 的 element
if processed_inner is not subquery.element:
subquery.element = processed_inner
return statement
# 处理当前层的表
delete_condition = None
delete_field = Constant.LOGICAL_DELETE_FIELD
for from_obj in statement.froms:
# 跳过子查询(因为会通过上面的逻辑递归处理)
if isinstance(from_obj, Subquery):
continue
# 只处理 Table 对象
if hasattr(from_obj, 'columns') and delete_field in from_obj.columns:
# 使用表对象上的列
condition = from_obj.columns[delete_field] == IsDelete.NO_DELETE
if delete_condition is None:
delete_condition = condition
else:
delete_condition = and_(delete_condition, condition)
# 如果有条件,则应用到当前 statement
if delete_condition is not None:
existing_condition = statement.whereclause
if existing_condition is not None:
new_condition = and_(existing_condition, delete_condition)
else:
new_condition = delete_condition
statement = statement.where(new_condition)
return statement
async def scalar(self, statement: Executable, params=None, *, execution_options=None, bind_arguments=None,
**kw: Any, ):
sig = inspect.signature(super().scalar) sig = inspect.signature(super().scalar)
if execution_options is None: if execution_options is None:
default_execution_options = sig.parameters['execution_options'].default default_execution_options = sig.parameters['execution_options'].default
execution_options = default_execution_options execution_options = default_execution_options
delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE
existing_condition = statement.whereclause # 只对 Select 语句添加逻辑删除条件
# 组合条件 if isinstance(statement, Select):
if existing_condition is not None: statement = self._add_logical_delete_condition(statement)
# 使用and_组合现有条件和逻辑删除条件
new_condition = and_(existing_condition, delete_condition)
else:
new_condition = delete_condition
# 应用新条件创建新的Select对象
statement = statement.where(new_condition)
return await super().scalar(statement, params, execution_options=execution_options, return await super().scalar(statement, params, execution_options=execution_options,
bind_arguments=bind_arguments, **kw) bind_arguments=bind_arguments, **kw)
async def execute( async def execute(self, statement: Executable, params=None, *, execution_options=None, bind_arguments=None,
self, **kw: Any, ) -> Result[Any]:
statement: Executable,
params=None,
*,
execution_options=None,
bind_arguments=None,
**kw: Any,
) -> Result[Any]:
sig = inspect.signature(super().execute) sig = inspect.signature(super().execute)
if execution_options is None: if execution_options is None:
default_execution_options = sig.parameters['execution_options'].default default_execution_options = sig.parameters['execution_options'].default
execution_options = default_execution_options execution_options = default_execution_options
# print("type(statement):{}", type(statement))
if isinstance(statement, Select): if isinstance(statement, Select):
# print("这是查询语句,过滤逻辑删除") statement = self._add_logical_delete_condition(statement)
delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE
# 获取现有条件
existing_condition = statement.whereclause
# 组合条件
if existing_condition is not None:
# 使用and_组合现有条件和逻辑删除条件
new_condition = and_(existing_condition, delete_condition)
else:
new_condition = delete_condition
# 应用新条件创建新的Select对象
statement = statement.where(new_condition)
if isinstance(statement, Delete): if isinstance(statement, Delete):
# 检查是否跳过软删除通过execution_options控制
skip_soft_delete = execution_options and execution_options.get("skip_soft_delete", False) skip_soft_delete = execution_options and execution_options.get("skip_soft_delete", False)
if not skip_soft_delete: if not skip_soft_delete:
# 获取表对象
table = statement.table table = statement.table
# 构建更新语句 if hasattr(table, 'columns') and Constant.LOGICAL_DELETE_FIELD in table.columns:
update_stmt = ( update_stmt = (Update(table).where(statement.whereclause).values(
Update(table) **{Constant.LOGICAL_DELETE_FIELD: IsDelete.DELETE}))
.where(statement.whereclause) # 保留原删除条件
.values(**{Constant.LOGICAL_DELETE_FIELD: IsDelete.DELETE}) # 设置软删除标记 if statement._returning:
) update_stmt = update_stmt.returning(*statement._returning)
# 如果原删除语句有RETURNING子句也添加到更新语句中 return await super().execute(update_stmt, params=params, execution_options=execution_options,
if statement._returning: bind_arguments=bind_arguments, **kw)
update_stmt = update_stmt.returning(*statement._returning)
# 执行更新语句 result = await super().execute(statement, params=params, execution_options=execution_options,
return await super().execute( bind_arguments=bind_arguments, **kw, )
update_stmt,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw
)
result = await super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)
return result return result
# 重写delete方法处理单个对象删除
def delete(self, instance): def delete(self, instance):
from sqlalchemy import inspect from sqlalchemy import inspect as sqlalchemy_inspect
# 检查是否有逻辑删除属性
if hasattr(instance, Constant.LOGICAL_DELETE_FIELD): if hasattr(instance, Constant.LOGICAL_DELETE_FIELD):
# 设置软删除标记 setattr(instance, Constant.LOGICAL_DELETE_FIELD, IsDelete.DELETE)
instance.__setattr__(Constant.LOGICAL_DELETE_FIELD, IsDelete.DELETE)
# 确保对象在会话中(如果已分离则重新关联)
# 检查对象状态
insp = inspect(instance)
if insp.detached:
# 如果对象是分离的,则重新加入会话
self.add(instance)
elif insp.transient:
# 如果是瞬态对象,也添加到会话
self.add(instance)
# 标记对象为已修改(触发更新) insp = sqlalchemy_inspect(instance)
# self.expire(instance, [Constant.LOGICAL_DELETE_FIELD]) if insp.detached or insp.transient:
self.add(instance)
else: else:
# 如果没有逻辑删除属性,执行标准删除
super().delete(instance) super().delete(instance)
AsyncSessionLocal = async_sessionmaker( AsyncSessionLocal = async_sessionmaker(bind=engine, class_=EnhanceAsyncSession, expire_on_commit=False, # 提交后不使对象过期
bind=engine, autoflush=False # 禁用自动刷新
class_=EnhanceAsyncSession, )
expire_on_commit=False, # 提交后不使对象过期
autoflush=False # 禁用自动刷新
)
# 获取数据库session的独立方法 # 获取数据库session的独立方法
@ -152,7 +144,6 @@ async def get_db_session():
yield session yield session
def with_db_session(session_param_name: str = "session"): def with_db_session(session_param_name: str = "session"):
""" """
一个装饰器用于为异步函数自动注入数据库会话 一个装饰器用于为异步函数自动注入数据库会话
@ -160,6 +151,7 @@ def with_db_session(session_param_name: str = "session"):
Args: Args:
session_param_name: 被装饰函数中用于接收会话的参数名默认为 'session' session_param_name: 被装饰函数中用于接收会话的参数名默认为 'session'
""" """
def decorator(func: Callable[P, T]) -> Callable[P, T]: def decorator(func: Callable[P, T]) -> Callable[P, T]:
# 确保只装饰异步函数 # 确保只装饰异步函数
if not inspect.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
@ -177,8 +169,8 @@ def with_db_session(session_param_name: str = "session"):
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper return wrapper
return decorator
return decorator
# 关闭引擎 # 关闭引擎

Loading…
Cancel
Save