|
|
|
|
@ -3,20 +3,18 @@ import inspect
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
from typing import Any, ParamSpec, TypeVar, Callable
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_
|
|
|
|
|
from sqlalchemy import Executable, Result, and_, Select, Delete, Update
|
|
|
|
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
|
|
|
|
from sqlalchemy.sql.selectable import Subquery
|
|
|
|
|
|
|
|
|
|
from common.constant import Constant
|
|
|
|
|
from common.global_enums import IsDelete
|
|
|
|
|
from config import settings
|
|
|
|
|
|
|
|
|
|
from entity.base_entity import DbBaseModel
|
|
|
|
|
P = ParamSpec('P')
|
|
|
|
|
T = TypeVar('T')
|
|
|
|
|
|
|
|
|
|
engine = create_async_engine(
|
|
|
|
|
settings.database_url,
|
|
|
|
|
echo=False, # 打印SQL日志(生产环境建议关闭)
|
|
|
|
|
engine = create_async_engine(settings.database_url, echo=False, # 打印SQL日志(生产环境建议关闭)
|
|
|
|
|
pool_size=10, # 连接池大小
|
|
|
|
|
max_overflow=20, # 最大溢出连接数
|
|
|
|
|
pool_recycle=3600, # 连接回收时间(秒),解决MySQL超时断开问题【4†source】【5†source】
|
|
|
|
|
@ -24,121 +22,115 @@ engine = create_async_engine(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 创建异步会话工厂
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EnhanceAsyncSession(AsyncSession):
|
|
|
|
|
async def scalar(self, statement: Executable,
|
|
|
|
|
params=None,
|
|
|
|
|
*,
|
|
|
|
|
execution_options=None,
|
|
|
|
|
bind_arguments=None,
|
|
|
|
|
**kw: Any, ):
|
|
|
|
|
|
|
|
|
|
sig = inspect.signature(super().scalar)
|
|
|
|
|
if execution_options is None:
|
|
|
|
|
default_execution_options = sig.parameters['execution_options'].default
|
|
|
|
|
execution_options = default_execution_options
|
|
|
|
|
delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE
|
|
|
|
|
def _add_logical_delete_condition(self, statement: Select) -> Select:
|
|
|
|
|
"""
|
|
|
|
|
为 Select 语句添加逻辑删除条件
|
|
|
|
|
支持递归处理 fastapi-pagination 生成的子查询包装
|
|
|
|
|
"""
|
|
|
|
|
# 特殊处理:如果当前查询只有一个子查询,直接修改其内部元素
|
|
|
|
|
# 这样可以避免产生重复的别名
|
|
|
|
|
if len(statement.froms) == 1 and isinstance(statement.froms[0], Subquery):
|
|
|
|
|
subquery = statement.froms[0]
|
|
|
|
|
# 递归处理子查询内部
|
|
|
|
|
processed_inner = self._add_logical_delete_condition(subquery.element)
|
|
|
|
|
# 如果内部有修改,直接替换 subquery 的 element
|
|
|
|
|
if processed_inner is not subquery.element:
|
|
|
|
|
subquery.element = processed_inner
|
|
|
|
|
return statement
|
|
|
|
|
|
|
|
|
|
# 处理当前层的表
|
|
|
|
|
delete_condition = None
|
|
|
|
|
delete_field = Constant.LOGICAL_DELETE_FIELD
|
|
|
|
|
|
|
|
|
|
for from_obj in statement.froms:
|
|
|
|
|
# 跳过子查询(因为会通过上面的逻辑递归处理)
|
|
|
|
|
if isinstance(from_obj, Subquery):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 只处理 Table 对象
|
|
|
|
|
if hasattr(from_obj, 'columns') and delete_field in from_obj.columns:
|
|
|
|
|
# 使用表对象上的列
|
|
|
|
|
condition = from_obj.columns[delete_field] == IsDelete.NO_DELETE
|
|
|
|
|
|
|
|
|
|
if delete_condition is None:
|
|
|
|
|
delete_condition = condition
|
|
|
|
|
else:
|
|
|
|
|
delete_condition = and_(delete_condition, condition)
|
|
|
|
|
|
|
|
|
|
# 如果有条件,则应用到当前 statement
|
|
|
|
|
if delete_condition is not None:
|
|
|
|
|
existing_condition = statement.whereclause
|
|
|
|
|
# 组合条件
|
|
|
|
|
if existing_condition is not None:
|
|
|
|
|
# 使用and_组合现有条件和逻辑删除条件
|
|
|
|
|
new_condition = and_(existing_condition, delete_condition)
|
|
|
|
|
else:
|
|
|
|
|
new_condition = delete_condition
|
|
|
|
|
# 应用新条件(创建新的Select对象)
|
|
|
|
|
|
|
|
|
|
statement = statement.where(new_condition)
|
|
|
|
|
|
|
|
|
|
return statement
|
|
|
|
|
|
|
|
|
|
async def scalar(self, statement: Executable, params=None, *, execution_options=None, bind_arguments=None,
|
|
|
|
|
**kw: Any, ):
|
|
|
|
|
sig = inspect.signature(super().scalar)
|
|
|
|
|
if execution_options is None:
|
|
|
|
|
default_execution_options = sig.parameters['execution_options'].default
|
|
|
|
|
execution_options = default_execution_options
|
|
|
|
|
|
|
|
|
|
# 只对 Select 语句添加逻辑删除条件
|
|
|
|
|
if isinstance(statement, Select):
|
|
|
|
|
statement = self._add_logical_delete_condition(statement)
|
|
|
|
|
|
|
|
|
|
return await super().scalar(statement, params, execution_options=execution_options,
|
|
|
|
|
bind_arguments=bind_arguments, **kw)
|
|
|
|
|
|
|
|
|
|
async def execute(
|
|
|
|
|
self,
|
|
|
|
|
statement: Executable,
|
|
|
|
|
params=None,
|
|
|
|
|
*,
|
|
|
|
|
execution_options=None,
|
|
|
|
|
bind_arguments=None,
|
|
|
|
|
**kw: Any,
|
|
|
|
|
) -> Result[Any]:
|
|
|
|
|
async def execute(self, statement: Executable, params=None, *, execution_options=None, bind_arguments=None,
|
|
|
|
|
**kw: Any, ) -> Result[Any]:
|
|
|
|
|
sig = inspect.signature(super().execute)
|
|
|
|
|
if execution_options is None:
|
|
|
|
|
default_execution_options = sig.parameters['execution_options'].default
|
|
|
|
|
execution_options = default_execution_options
|
|
|
|
|
# print("type(statement):{}", type(statement))
|
|
|
|
|
|
|
|
|
|
if isinstance(statement, Select):
|
|
|
|
|
# print("这是查询语句,过滤逻辑删除")
|
|
|
|
|
delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE
|
|
|
|
|
# 获取现有条件
|
|
|
|
|
existing_condition = statement.whereclause
|
|
|
|
|
# 组合条件
|
|
|
|
|
if existing_condition is not None:
|
|
|
|
|
# 使用and_组合现有条件和逻辑删除条件
|
|
|
|
|
new_condition = and_(existing_condition, delete_condition)
|
|
|
|
|
else:
|
|
|
|
|
new_condition = delete_condition
|
|
|
|
|
# 应用新条件(创建新的Select对象)
|
|
|
|
|
statement = statement.where(new_condition)
|
|
|
|
|
statement = self._add_logical_delete_condition(statement)
|
|
|
|
|
|
|
|
|
|
if isinstance(statement, Delete):
|
|
|
|
|
# 检查是否跳过软删除(通过execution_options控制)
|
|
|
|
|
skip_soft_delete = execution_options and execution_options.get("skip_soft_delete", False)
|
|
|
|
|
|
|
|
|
|
if not skip_soft_delete:
|
|
|
|
|
# 获取表对象
|
|
|
|
|
table = statement.table
|
|
|
|
|
|
|
|
|
|
# 构建更新语句
|
|
|
|
|
update_stmt = (
|
|
|
|
|
Update(table)
|
|
|
|
|
.where(statement.whereclause) # 保留原删除条件
|
|
|
|
|
.values(**{Constant.LOGICAL_DELETE_FIELD: IsDelete.DELETE}) # 设置软删除标记
|
|
|
|
|
)
|
|
|
|
|
if hasattr(table, 'columns') and Constant.LOGICAL_DELETE_FIELD in table.columns:
|
|
|
|
|
update_stmt = (Update(table).where(statement.whereclause).values(
|
|
|
|
|
**{Constant.LOGICAL_DELETE_FIELD: IsDelete.DELETE}))
|
|
|
|
|
|
|
|
|
|
# 如果原删除语句有RETURNING子句,也添加到更新语句中
|
|
|
|
|
if statement._returning:
|
|
|
|
|
update_stmt = update_stmt.returning(*statement._returning)
|
|
|
|
|
# 执行更新语句
|
|
|
|
|
return await super().execute(
|
|
|
|
|
update_stmt,
|
|
|
|
|
params=params,
|
|
|
|
|
execution_options=execution_options,
|
|
|
|
|
bind_arguments=bind_arguments,
|
|
|
|
|
**kw
|
|
|
|
|
)
|
|
|
|
|
result = await super().execute(
|
|
|
|
|
statement,
|
|
|
|
|
params=params,
|
|
|
|
|
execution_options=execution_options,
|
|
|
|
|
bind_arguments=bind_arguments,
|
|
|
|
|
**kw,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return await super().execute(update_stmt, params=params, execution_options=execution_options,
|
|
|
|
|
bind_arguments=bind_arguments, **kw)
|
|
|
|
|
|
|
|
|
|
result = await super().execute(statement, params=params, execution_options=execution_options,
|
|
|
|
|
bind_arguments=bind_arguments, **kw, )
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
# 重写delete方法处理单个对象删除
|
|
|
|
|
def delete(self, instance):
|
|
|
|
|
from sqlalchemy import inspect
|
|
|
|
|
# 检查是否有逻辑删除属性
|
|
|
|
|
from sqlalchemy import inspect as sqlalchemy_inspect
|
|
|
|
|
|
|
|
|
|
if hasattr(instance, Constant.LOGICAL_DELETE_FIELD):
|
|
|
|
|
# 设置软删除标记
|
|
|
|
|
instance.__setattr__(Constant.LOGICAL_DELETE_FIELD, IsDelete.DELETE)
|
|
|
|
|
# 确保对象在会话中(如果已分离则重新关联)
|
|
|
|
|
# 检查对象状态
|
|
|
|
|
insp = inspect(instance)
|
|
|
|
|
if insp.detached:
|
|
|
|
|
# 如果对象是分离的,则重新加入会话
|
|
|
|
|
self.add(instance)
|
|
|
|
|
elif insp.transient:
|
|
|
|
|
# 如果是瞬态对象,也添加到会话
|
|
|
|
|
self.add(instance)
|
|
|
|
|
setattr(instance, Constant.LOGICAL_DELETE_FIELD, IsDelete.DELETE)
|
|
|
|
|
|
|
|
|
|
# 标记对象为已修改(触发更新)
|
|
|
|
|
# self.expire(instance, [Constant.LOGICAL_DELETE_FIELD])
|
|
|
|
|
insp = sqlalchemy_inspect(instance)
|
|
|
|
|
if insp.detached or insp.transient:
|
|
|
|
|
self.add(instance)
|
|
|
|
|
else:
|
|
|
|
|
# 如果没有逻辑删除属性,执行标准删除
|
|
|
|
|
super().delete(instance)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AsyncSessionLocal = async_sessionmaker(
|
|
|
|
|
bind=engine,
|
|
|
|
|
class_=EnhanceAsyncSession,
|
|
|
|
|
expire_on_commit=False, # 提交后不使对象过期
|
|
|
|
|
AsyncSessionLocal = async_sessionmaker(bind=engine, class_=EnhanceAsyncSession, expire_on_commit=False, # 提交后不使对象过期
|
|
|
|
|
autoflush=False # 禁用自动刷新
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -152,7 +144,6 @@ async def get_db_session():
|
|
|
|
|
yield session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def with_db_session(session_param_name: str = "session"):
|
|
|
|
|
"""
|
|
|
|
|
一个装饰器,用于为异步函数自动注入数据库会话。
|
|
|
|
|
@ -160,6 +151,7 @@ def with_db_session(session_param_name: str = "session"):
|
|
|
|
|
Args:
|
|
|
|
|
session_param_name: 被装饰函数中,用于接收会话的参数名,默认为 'session'。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
|
|
|
|
# 确保只装饰异步函数
|
|
|
|
|
if not inspect.iscoroutinefunction(func):
|
|
|
|
|
@ -177,8 +169,8 @@ def with_db_session(session_param_name: str = "session"):
|
|
|
|
|
return await func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
return wrapper
|
|
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 关闭引擎
|
|
|
|
|
|