From 521bf98c56c66d7da7512075dee35cbb6e61f0bb Mon Sep 17 00:00:00 2001 From: chenzhirong <826531489@qq.com> Date: Fri, 26 Sep 2025 10:04:30 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96unified=5Fresp?= =?UTF-8?q?=E8=A3=85=E9=A5=B0=E5=99=A8=EF=BC=88=E7=BB=9F=E4=B8=80=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E5=93=8D=E5=BA=94=EF=BC=89=EF=BC=8C=E7=A1=AE=E4=BF=9D?= =?UTF-8?q?=20api=E6=96=87=E6=A1=A3=E5=B1=95=E7=A4=BA=20=E7=9A=84=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E7=B1=BB=E5=9E=8B=E4=B8=8E=E7=9C=9F=E5=AE=9E=E7=9A=84?= =?UTF-8?q?=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- entity/dto/__init__.py | 13 ++++++++++++- entity/dto/base.py | 7 ++++--- router/__init__.py | 15 ++++++++++++--- router/user_app.py | 8 ++++++-- service/base_service.py | 26 +++++++++++++------------- 5 files changed, 47 insertions(+), 22 deletions(-) diff --git a/entity/dto/__init__.py b/entity/dto/__init__.py index 57dca47..b52541a 100644 --- a/entity/dto/__init__.py +++ b/entity/dto/__init__.py @@ -1,10 +1,15 @@ from collections import namedtuple +from typing import TypeVar, Generic, Optional from beartype.claw import beartype_this_package +from pydantic import BaseModel beartype_this_package() HttpCode = namedtuple('HttpResp', ['code', 'msg']) +T = TypeVar("T") + + class HttpResp: """HTTP响应结果 """ @@ -24,4 +29,10 @@ class HttpResp: REQUEST_404_ERROR = HttpCode(404, '请求接口不存在') DATA_ALREADY_EXISTS = HttpCode(409, '数据已存在') SYSTEM_ERROR = HttpCode(500, '系统错误') - SYSTEM_TIMEOUT_ERROR = HttpCode(504, '请求超时') \ No newline at end of file + SYSTEM_TIMEOUT_ERROR = HttpCode(504, '请求超时') + + +class ApiResponse(BaseModel, Generic[T]): + code: int = HttpResp.SUCCESS.code + message: str = HttpResp.SUCCESS.msg + data: Optional[T] = None diff --git a/entity/dto/base.py b/entity/dto/base.py index bed90ae..d6f9b4d 100644 --- a/entity/dto/base.py +++ b/entity/dto/base.py @@ -1,8 +1,9 @@ -from typing import Optional, List, Any +from typing import Optional, List, Any, Generic, TypeVar from typing import Union from pydantic import BaseModel, Field +T = TypeVar('T') class BaseTabelDto(BaseModel): id: Optional[str] = None @@ -27,14 +28,14 @@ class BaseRenameReq(BaseModel): name: str -class BasePageResp(BaseModel): +class BasePageResp(BaseModel, Generic[T]): page_number: Optional[int] page_size: Optional[int] page_count: Optional[int] sort: Optional[str] orderby: Optional[str] count: Optional[int] - data: Optional[List[Any]] + data: Optional[List[T]] class Config: arbitrary_types_allowed = True diff --git a/router/__init__.py b/router/__init__.py index 7d95d1c..3d4a329 100644 --- a/router/__init__.py +++ b/router/__init__.py @@ -1,8 +1,8 @@ import inspect import logging -from functools import wraps -from typing import Union, Type, Callable, TypeVar from datetime import datetime +from functools import wraps +from typing import Union, Type, Callable, TypeVar, get_type_hints import pytz from fastapi.encoders import jsonable_encoder @@ -10,15 +10,24 @@ from pydantic import BaseModel from starlette.responses import JSONResponse from config import get_settings -from entity.dto import HttpResp +from entity.dto import HttpResp, ApiResponse from exceptions.base import AppException from utils import get_uuid RT = TypeVar('RT') # 返回类型 + + def unified_resp(func: Callable[..., RT]) -> Callable[..., RT]: """统一响应格式 接口正常返回时,统一响应结果格式 """ + # 获取原始函数的返回类型注解 + hints = get_type_hints(func) + return_type = hints.get('return', None) + + # 修改函数的返回类型注解 + if return_type: + func.__annotations__['return'] = ApiResponse[return_type] @wraps(func) async def wrapper(*args, **kwargs) -> RT: diff --git a/router/user_app.py b/router/user_app.py index c844fc7..d4fa88e 100644 --- a/router/user_app.py +++ b/router/user_app.py @@ -1,5 +1,9 @@ +from typing import List + from fastapi import APIRouter, Query +from entity.db_models import User +from entity.dto.base import BasePageResp from entity.dto.user_dto import UserQueryPageReq, UserQueryReq from router import BaseController, unified_resp from service.user_service import UserService @@ -10,10 +14,10 @@ base_app = BaseController(base_service) @router.get("/page") @unified_resp -async def get_page(req:UserQueryPageReq=Query(...)): +async def get_page(req:UserQueryPageReq=Query(...)) -> BasePageResp[User]: return await base_service.get_by_page(req) @router.get("/list") @unified_resp -async def get_list(req:UserQueryReq=Query(...)): +async def get_list(req:UserQueryReq=Query(...))->List[User]: return await base_service.get_list(req) diff --git a/service/base_service.py b/service/base_service.py index c35cb9b..2d4f3a0 100644 --- a/service/base_service.py +++ b/service/base_service.py @@ -1,11 +1,10 @@ -from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional +from typing import Union, Type, List, Any, TypeVar, Generic from fastapi_pagination import Params from fastapi_pagination.ext.sqlalchemy import paginate from pydantic import BaseModel from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import SQLModel from core.global_context import current_session from entity import DbBaseModel @@ -22,6 +21,7 @@ session.scalar: 直接明确获取一条数据,可以直接返回,无需额 """ T = TypeVar('T', bound=DbBaseModel) + class BaseService(Generic[T]): model: Type[T] # 子类必须指定模型 @@ -66,7 +66,7 @@ class BaseService(Generic[T]): pass @classmethod - async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp: + async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} @@ -75,7 +75,7 @@ class BaseService(Generic[T]): @classmethod async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, - dto_model_class: Type[BaseModel] = None)->BasePageResp: + dto_model_class: Type[BaseModel] = None) -> BasePageResp[T]: if not query_params: query_params = {} if not isinstance(query_params, dict): @@ -116,7 +116,7 @@ class BaseService(Generic[T]): }) @classmethod - async def get_list(cls, query_params: Union[dict, BaseQueryReq])->List[T]: + async def get_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[T]: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} @@ -150,7 +150,7 @@ class BaseService(Generic[T]): return [item["id"] for item in exec_result.scalars().all()] @classmethod - async def save(cls, **kwargs)->T: + async def save(cls, **kwargs) -> T: sample_obj = cls.model(**kwargs) session = cls.get_db() session.add(sample_obj) @@ -158,7 +158,7 @@ class BaseService(Generic[T]): return sample_obj @classmethod - async def insert_many(cls, data_list, batch_size=100)->None: + async def insert_many(cls, data_list, batch_size=100) -> None: async with cls.get_db() as session: for d in data_list: if not d.get("id", None): @@ -168,27 +168,27 @@ class BaseService(Generic[T]): session.add_all(data_list[i: i + batch_size]) @classmethod - async def update_by_id(cls, pid, data)-> int: + async def update_by_id(cls, pid, data) -> int: update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) session = cls.get_db() result = await session.execute(update_stmt) return result.rowcount @classmethod - async def update_many_by_id(cls, data_list)->None: + async def update_many_by_id(cls, data_list) -> None: async with cls.get_db() as session: for data in data_list: stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) await session.execute(stmt) @classmethod - async def get_by_id(cls, pid)->T: + async def get_by_id(cls, pid) -> T: stmt = cls.model.select().where(cls.model.id == pid) session = cls.get_db() return await session.scalar(stmt) @classmethod - async def get_by_ids(cls, pids, cols=None)->List[T]: + async def get_by_ids(cls, pids, cols=None) -> List[T]: if cols: objs = cls.model.select(*cols) else: @@ -199,14 +199,14 @@ class BaseService(Generic[T]): return list(result.all()) @classmethod - async def delete_by_id(cls, pid)-> int: + async def delete_by_id(cls, pid) -> int: del_stmt = cls.model.delete().where(cls.model.id == pid) session = cls.get_db() exec_result = await session.execute(del_stmt) return exec_result.rowcount @classmethod - async def delete_by_ids(cls, pids)-> int: + async def delete_by_ids(cls, pids) -> int: session = cls.get_db() del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) result = await session.execute(del_stmt)