diff --git a/entity/__init__.py b/entity/__init__.py index 098388f..d5fcf55 100644 --- a/entity/__init__.py +++ b/entity/__init__.py @@ -3,144 +3,136 @@ import inspect from contextlib import asynccontextmanager from typing import Any, ParamSpec, TypeVar, Callable -from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_ +from sqlalchemy import Executable, Result, and_, Select, Delete, Update from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession +from sqlalchemy.sql.selectable import Subquery from common.constant import Constant from common.global_enums import IsDelete from config import settings -from entity.base_entity import DbBaseModel P = ParamSpec('P') T = TypeVar('T') -engine = create_async_engine( - settings.database_url, - echo=False, # 打印SQL日志(生产环境建议关闭) - pool_size=10, # 连接池大小 - max_overflow=20, # 最大溢出连接数 - pool_recycle=3600, # 连接回收时间(秒),解决MySQL超时断开问题【4†source】【5†source】 -) +engine = create_async_engine(settings.database_url, echo=False, # 打印SQL日志(生产环境建议关闭) + pool_size=10, # 连接池大小 + max_overflow=20, # 最大溢出连接数 + pool_recycle=3600, # 连接回收时间(秒),解决MySQL超时断开问题【4†source】【5†source】 + ) # 创建异步会话工厂 + + class EnhanceAsyncSession(AsyncSession): - async def scalar(self, statement: Executable, - params=None, - *, - execution_options=None, - bind_arguments=None, - **kw: Any, ): + def _add_logical_delete_condition(self, statement: Select) -> Select: + """ + 为 Select 语句添加逻辑删除条件 + 支持递归处理 fastapi-pagination 生成的子查询包装 + """ + # 特殊处理:如果当前查询只有一个子查询,直接修改其内部元素 + # 这样可以避免产生重复的别名 + if len(statement.froms) == 1 and isinstance(statement.froms[0], Subquery): + subquery = statement.froms[0] + # 递归处理子查询内部 + processed_inner = self._add_logical_delete_condition(subquery.element) + # 如果内部有修改,直接替换 subquery 的 element + if processed_inner is not subquery.element: + subquery.element = processed_inner + return statement + + # 处理当前层的表 + delete_condition = None + delete_field = Constant.LOGICAL_DELETE_FIELD + + for from_obj in statement.froms: + # 跳过子查询(因为会通过上面的逻辑递归处理) + if isinstance(from_obj, Subquery): + continue + + # 只处理 Table 对象 + if hasattr(from_obj, 'columns') and delete_field in from_obj.columns: + # 使用表对象上的列 + condition = from_obj.columns[delete_field] == IsDelete.NO_DELETE + + if delete_condition is None: + delete_condition = condition + else: + delete_condition = and_(delete_condition, condition) + + # 如果有条件,则应用到当前 statement + if delete_condition is not None: + existing_condition = statement.whereclause + if existing_condition is not None: + new_condition = and_(existing_condition, delete_condition) + else: + new_condition = delete_condition + + statement = statement.where(new_condition) + + return statement + + async def scalar(self, statement: Executable, params=None, *, execution_options=None, bind_arguments=None, + **kw: Any, ): sig = inspect.signature(super().scalar) if execution_options is None: default_execution_options = sig.parameters['execution_options'].default execution_options = default_execution_options - delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE - existing_condition = statement.whereclause - # 组合条件 - if existing_condition is not None: - # 使用and_组合现有条件和逻辑删除条件 - new_condition = and_(existing_condition, delete_condition) - else: - new_condition = delete_condition - # 应用新条件(创建新的Select对象) - statement = statement.where(new_condition) + + # 只对 Select 语句添加逻辑删除条件 + if isinstance(statement, Select): + statement = self._add_logical_delete_condition(statement) + return await super().scalar(statement, params, execution_options=execution_options, bind_arguments=bind_arguments, **kw) - async def execute( - self, - statement: Executable, - params=None, - *, - execution_options=None, - bind_arguments=None, - **kw: Any, - ) -> Result[Any]: + async def execute(self, statement: Executable, params=None, *, execution_options=None, bind_arguments=None, + **kw: Any, ) -> Result[Any]: sig = inspect.signature(super().execute) if execution_options is None: default_execution_options = sig.parameters['execution_options'].default execution_options = default_execution_options - # print("type(statement):{}", type(statement)) + if isinstance(statement, Select): - # print("这是查询语句,过滤逻辑删除") - delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE - # 获取现有条件 - existing_condition = statement.whereclause - # 组合条件 - if existing_condition is not None: - # 使用and_组合现有条件和逻辑删除条件 - new_condition = and_(existing_condition, delete_condition) - else: - new_condition = delete_condition - # 应用新条件(创建新的Select对象) - statement = statement.where(new_condition) + statement = self._add_logical_delete_condition(statement) + if isinstance(statement, Delete): - # 检查是否跳过软删除(通过execution_options控制) skip_soft_delete = execution_options and execution_options.get("skip_soft_delete", False) if not skip_soft_delete: - # 获取表对象 table = statement.table - # 构建更新语句 - update_stmt = ( - Update(table) - .where(statement.whereclause) # 保留原删除条件 - .values(**{Constant.LOGICAL_DELETE_FIELD: IsDelete.DELETE}) # 设置软删除标记 - ) - - # 如果原删除语句有RETURNING子句,也添加到更新语句中 - if statement._returning: - update_stmt = update_stmt.returning(*statement._returning) - # 执行更新语句 - return await super().execute( - update_stmt, - params=params, - execution_options=execution_options, - bind_arguments=bind_arguments, - **kw - ) - result = await super().execute( - statement, - params=params, - execution_options=execution_options, - bind_arguments=bind_arguments, - **kw, - ) + if hasattr(table, 'columns') and Constant.LOGICAL_DELETE_FIELD in table.columns: + update_stmt = (Update(table).where(statement.whereclause).values( + **{Constant.LOGICAL_DELETE_FIELD: IsDelete.DELETE})) + + if statement._returning: + update_stmt = update_stmt.returning(*statement._returning) + + return await super().execute(update_stmt, params=params, execution_options=execution_options, + bind_arguments=bind_arguments, **kw) + + result = await super().execute(statement, params=params, execution_options=execution_options, + bind_arguments=bind_arguments, **kw, ) return result - # 重写delete方法处理单个对象删除 def delete(self, instance): - from sqlalchemy import inspect - # 检查是否有逻辑删除属性 + from sqlalchemy import inspect as sqlalchemy_inspect + if hasattr(instance, Constant.LOGICAL_DELETE_FIELD): - # 设置软删除标记 - instance.__setattr__(Constant.LOGICAL_DELETE_FIELD, IsDelete.DELETE) - # 确保对象在会话中(如果已分离则重新关联) - # 检查对象状态 - insp = inspect(instance) - if insp.detached: - # 如果对象是分离的,则重新加入会话 - self.add(instance) - elif insp.transient: - # 如果是瞬态对象,也添加到会话 - self.add(instance) + setattr(instance, Constant.LOGICAL_DELETE_FIELD, IsDelete.DELETE) - # 标记对象为已修改(触发更新) - # self.expire(instance, [Constant.LOGICAL_DELETE_FIELD]) + insp = sqlalchemy_inspect(instance) + if insp.detached or insp.transient: + self.add(instance) else: - # 如果没有逻辑删除属性,执行标准删除 super().delete(instance) -AsyncSessionLocal = async_sessionmaker( - bind=engine, - class_=EnhanceAsyncSession, - expire_on_commit=False, # 提交后不使对象过期 - autoflush=False # 禁用自动刷新 -) +AsyncSessionLocal = async_sessionmaker(bind=engine, class_=EnhanceAsyncSession, expire_on_commit=False, # 提交后不使对象过期 + autoflush=False # 禁用自动刷新 + ) # 获取数据库session的独立方法 @@ -152,7 +144,6 @@ async def get_db_session(): yield session - def with_db_session(session_param_name: str = "session"): """ 一个装饰器,用于为异步函数自动注入数据库会话。 @@ -160,6 +151,7 @@ def with_db_session(session_param_name: str = "session"): Args: session_param_name: 被装饰函数中,用于接收会话的参数名,默认为 'session'。 """ + def decorator(func: Callable[P, T]) -> Callable[P, T]: # 确保只装饰异步函数 if not inspect.iscoroutinefunction(func): @@ -177,8 +169,8 @@ def with_db_session(session_param_name: str = "session"): return await func(*args, **kwargs) return wrapper - return decorator + return decorator # 关闭引擎