From eea9687c50cdbc677187653ad8923d84d5594e7f Mon Sep 17 00:00:00 2001 From: chenzhirong <826531489@qq.com> Date: Mon, 12 Jan 2026 00:30:59 +0800 Subject: [PATCH] =?UTF-8?q?perf=EF=BC=9A=E4=BC=98=E5=8C=96=E9=83=A8?= =?UTF-8?q?=E5=88=86=E5=AD=97=E6=AE=B5=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- entity/base_entity.py | 7 +- router/__init__.py | 30 +++---- service/base_service.py | 172 +++++++++++++++++++++------------------- 3 files changed, 105 insertions(+), 104 deletions(-) diff --git a/entity/base_entity.py b/entity/base_entity.py index 0d796ae..5a6a9a9 100644 --- a/entity/base_entity.py +++ b/entity/base_entity.py @@ -28,11 +28,10 @@ class DbBaseModel(SQLModel, table=False): # class Config: # arbitrary_types_allowed = True @classmethod - def select(cls, fields=None): + def select(cls, fields: list = None): if fields is None: - fields = cls - return Select(fields) - + return Select(cls) + return Select(*fields) @classmethod def delete(cls): diff --git a/router/__init__.py b/router/__init__.py index 3d4a329..3d7a81b 100644 --- a/router/__init__.py +++ b/router/__init__.py @@ -14,6 +14,7 @@ from entity.dto import HttpResp, ApiResponse from exceptions.base import AppException from utils import get_uuid +__all__ = ["unified_resp", "BaseController"] RT = TypeVar('RT') # 返回类型 @@ -35,6 +36,7 @@ def unified_resp(func: Callable[..., RT]) -> Callable[..., RT]: resp = await func(*args, **kwargs) or [] else: resp = func(*args, **kwargs) or [] + return JSONResponse( content=jsonable_encoder( # 正常请求响应 @@ -58,33 +60,23 @@ class BaseController: async def base_page(self, req: Union[dict, BaseModel], dto_class: Type[BaseModel] = None): if not isinstance(req, dict): req = req.model_dump() - result = await self.service.get_by_page(req) - datas = result.data - if datas and dto_class: - result.data = self.service.entity_conversion_dto(datas, dto_class) + result = await self.service.get_by_page(req,dto_class) + # datas = result.data + # if datas and dto_class: + # result.data = self.service.entity_conversion_dto(datas, dto_class) return result async def base_list(self, req: Union[dict, BaseModel], dto_class: Type[BaseModel] = None): if not isinstance(req, dict): req = req.model_dump() - datas = await self.service.get_list(req) - if datas and dto_class: - datas = self.service.entity_conversion_dto(datas, dto_class) + datas = await self.service.get_list(req,dto_class) return datas - async def get_all(self, dto_class: Type[BaseModel] = None): - result = await self.service.get_all() - if dto_class: - result = self.service.entity_conversion_dto(result, dto_class) - return result async def get_by_id(self, id: str, dto_class: Type[BaseModel] = None): - data = await self.service.get_by_id(id) - if not data: + result = await self.service.get_by_id(id,dto_class) + if not result: raise AppException(f"不存在 id 为{id}的数据") - result = data.to_dict() - if dto_class: - result = self.service.entity_conversion_dto(result, dto_class) return result async def add(self, req: Union[dict, BaseModel]): @@ -102,7 +94,7 @@ class BaseController: db_query_data = await self.service.get_by_id(id) if not db_query_data: raise AppException(f"数据不存在") - self.service.check_base_permission(db_query_data) + await self.service.check_base_permission(db_query_data) try: return await self.service.delete_by_id(id) except Exception as e: @@ -117,7 +109,7 @@ class BaseController: db_query_data = await self.service.get_by_id(data_id) if not db_query_data: raise AppException(f"数据不存在") - self.service.check_base_permission(db_query_data) + await self.service.check_base_permission(db_query_data) try: return await self.service.update_by_id(data_id, req) diff --git a/service/base_service.py b/service/base_service.py index 7b4b6cd..2d4a357 100644 --- a/service/base_service.py +++ b/service/base_service.py @@ -29,7 +29,7 @@ class BaseService(Generic[T]): def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None): if stmt is None: if fields: - stmt = cls.model.select(*fields) + stmt = cls.model.select(fields) else: stmt = cls.model.select() for key, value in query_params.items(): @@ -48,6 +48,9 @@ class BaseService(Generic[T]): @classmethod def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[ BaseModel, List[BaseModel]]: + """ + 数据脱敏 + """ dto_list = [] if not isinstance(entity_data, list): return dto(**entity_data.model_dump()) @@ -59,23 +62,27 @@ class BaseService(Generic[T]): return dto_list @classmethod - def check_base_permission(cls, daba: Any): + async def check_base_permission(cls, daba: Any): # todo pass @classmethod - async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]: + async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq], dto_model_class: Type[BaseModel] = None) -> \ + BasePageResp[T]: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v not in [None, ""]} - query_stmt = cls.get_query_stmt(query_params) - return await cls.auto_page(query_stmt, query_params) + fields = None + if dto_model_class is not None: + fields = [getattr(cls.model, key) for key in dto_model_class.model_fields.keys()] + query_stmt = cls.get_query_stmt(query_params, fields=fields) + + return await cls.auto_page(query_stmt, query_params, dto_model_class) @classmethod @with_db_session() async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, - dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> \ - BasePageResp[T]: + dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> BasePageResp[T]: if not query_params: query_params = {} if not isinstance(query_params, dict): @@ -84,87 +91,45 @@ class BaseService(Generic[T]): page_size = query_params.get("page_size", 12) sort = query_params.get("sort", "desc") orderby = query_params.get("orderby", "created_time") - data_count = None - if data_count == 0: - return BasePageResp(**{ - "page_number": page_number, - "page_size": page_size, - "count": data_count, - "sort": sort, - "orderby": orderby, - "data": [], - }) + if sort == "desc": query_stmt = query_stmt.order_by(getattr(cls.model, orderby).desc()) else: query_stmt = query_stmt.order_by(getattr(cls.model, orderby).asc()) - query_page_result = await paginate(session, - query_stmt, - Params(page=page_number, size=page_size)) + + query_page_result = await paginate(session, query_stmt, Params(page=page_number, size=page_size)) result = query_page_result.items if dto_model_class is not None: - result = [dto_model_class(**item) for item in result] - return BasePageResp(**{ - "page_number": page_number, - "page_size": page_size, - "page_count": query_page_result.pages, - "count": query_page_result.total, - "sort": sort, - "orderby": orderby, - "data": result, - }) + # 使用 row._mapping 将 Row 对象转换为可解包的字典 + result = [dto_model_class(**dict(row._mapping)) for row in result] + return BasePageResp( + **{"page_number": page_number, "page_size": page_size, "page_count": query_page_result.pages, + "count": query_page_result.total, "sort": sort, "orderby": orderby, "data": result, }) @classmethod @with_db_session() - async def get_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> List[ - T]: - if not isinstance(query_params, dict): - query_params = query_params.model_dump() - query_params = {k: v for k, v in query_params.items() if v not in [None, ""]} - sort = query_params.get("sort", "desc") - orderby = query_params.get("orderby", "created_time") - query_stmt = cls.get_query_stmt(query_params) - field = getattr(cls.model, orderby) - if sort == "desc": - query_stmt = query_stmt.order_by(field.desc()) - else: - query_stmt = query_stmt.order_by(field.asc()) - if query_params.get("limit", None) is not None: - query_stmt = query_stmt.limit(query_params.get("limit")) + async def get_list(cls, query_params: Union[dict, BaseQueryReq], dto_model_class: Type[BaseModel] = None, *, + session: Optional[AsyncSession] = None) -> List[T] | List[BaseModel]: + """ + query_params: 参数字典 or 参数请求模型 + dto_model_class: 输出类型 + session: 数据库会话---支持传递以便于事务管控 + 获取数据集合 + """ + query_stmt = cls.build_query(query_params, dto_model_class=dto_model_class) exec_result = await session.execute(query_stmt) - return list(exec_result.scalars().all()) - - @classmethod - @with_db_session() - async def get_list_json(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \ - List[ - T]: - resp_list = await cls.get_list(query_params, session=session) - - return [i.model_dump() for i in resp_list] + return cls.parse_result(exec_result, dto_model_class=dto_model_class) @classmethod @with_db_session() - async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \ - List[str]: - if not isinstance(query_params, dict): - query_params = query_params.model_dump() - query_params = {k: v for k, v in query_params.items() if v is not None} - sort = query_params.get("sort", "desc") - orderby = query_params.get("orderby", "created_time") - query_stmt = cls.model.select(cls.model.id) - query_stmt = cls.get_query_stmt(query_params, query_stmt) - if sort == "desc": - query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc()) - else: - query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc()) - exec_result = await session.execute(query_stmt) - return [item["id"] for item in exec_result.scalars().all()] + async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None): + query_stmt = cls.build_query(query_params, fields=[cls.model.id]) + exec_result = await session.scalars(query_stmt) + return list(exec_result) @classmethod @with_db_session() async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T: - sample_obj = cls.model(**kwargs) session.add(sample_obj) await session.flush() @@ -206,7 +171,6 @@ class BaseService(Generic[T]): @classmethod @with_db_session() async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T: - stmt = cls.model.select().where(cls.model.id == pid) return await session.scalar(stmt) @@ -222,15 +186,12 @@ class BaseService(Generic[T]): @classmethod @with_db_session() - async def get_by_ids(cls, pids, cols=None, *, session: Optional[AsyncSession] = None) -> List[T]: - - if cols: - objs = cls.model.select(*cols) - else: - objs = cls.model.select() - stmt = objs.where(cls.model.id.in_(pids)) - result = await session.scalars(stmt) - return list(result.all()) + async def get_by_ids(cls, pids, dto_model_class: Type[BaseModel] = None, *, + session: Optional[AsyncSession] = None) -> List[T]: + stmt = cls.build_query({}, dto_model_class=dto_model_class) + stmt = stmt.where(cls.model.id.in_(pids)) + exec_result = await session.execute(stmt) + return cls.parse_result(exec_result, dto_model_class=dto_model_class) @classmethod @with_db_session() @@ -271,3 +232,52 @@ class BaseService(Generic[T]): @classmethod async def is_exist(cls, query_params: dict = None): return await cls.get_data_count(query_params) > 0 + + @classmethod + def build_query(cls, query_params: Union[dict, BaseQueryReq], *, dto_model_class: Type[BaseModel] = None, + fields: List = None): + """ + 可选dto_model_class、fields + 优先级dto_model_class>fields + 如果传递了dto_model_class则无需传递fields,fields会被dto_model_class覆盖 + """ + if not isinstance(query_params, dict): + query_params = query_params.model_dump() + query_params = {k: v for k, v in query_params.items() if v not in [None, ""]} + sort = query_params.get("sort", "desc").lower() + orderby = query_params.get("orderby", "created_time").lower() + if dto_model_class is not None: + # 安全映射:只获取 DTO 中存在且 Model 中也有的字段 + # 避免 DTO 包含计算字段时 getattr 报错 + db_columns = [] + for key in dto_model_class.model_fields.keys(): + if hasattr(cls.model, key): + db_columns.append(getattr(cls.model, key)) + fields = db_columns + + query_stmt = cls.get_query_stmt(query_params, fields=fields) + + if hasattr(cls.model, orderby): + order_field = getattr(cls.model, orderby) + else: + # 如果传入的 orderby 字段不存在,回退到默认排序,防止报错 + order_field = cls.model.created_time + + # 根据xxx字段排序 + if sort == "desc": + query_stmt = query_stmt.order_by(order_field.desc()) + else: + query_stmt = query_stmt.order_by(order_field.asc()) + + if query_params.get("limit", None) is not None: + query_stmt = query_stmt.limit(query_params.get("limit")) + return query_stmt + + @classmethod + def parse_result(cls, exec_result, dto_model_class: Type[BaseModel] = None): + """ + 将数据库执行结果解析为 DTO 列表或 实体列表 + """ + if dto_model_class is not None: + return [dto_model_class(**dict(row._mapping)) for row in exec_result.all()] + return list(exec_result.scalars().all())