perf: 优化unified_resp装饰器(统一接口响应),确保 api文档展示 的响应类型与真实的一致

main
chenzhirong 4 months ago
parent 9a211f9484
commit 521bf98c56

@ -1,10 +1,15 @@
from collections import namedtuple from collections import namedtuple
from typing import TypeVar, Generic, Optional
from beartype.claw import beartype_this_package from beartype.claw import beartype_this_package
from pydantic import BaseModel
beartype_this_package() beartype_this_package()
HttpCode = namedtuple('HttpResp', ['code', 'msg']) HttpCode = namedtuple('HttpResp', ['code', 'msg'])
T = TypeVar("T")
class HttpResp: class HttpResp:
"""HTTP响应结果 """HTTP响应结果
""" """
@ -24,4 +29,10 @@ class HttpResp:
REQUEST_404_ERROR = HttpCode(404, '请求接口不存在') REQUEST_404_ERROR = HttpCode(404, '请求接口不存在')
DATA_ALREADY_EXISTS = HttpCode(409, '数据已存在') DATA_ALREADY_EXISTS = HttpCode(409, '数据已存在')
SYSTEM_ERROR = HttpCode(500, '系统错误') SYSTEM_ERROR = HttpCode(500, '系统错误')
SYSTEM_TIMEOUT_ERROR = HttpCode(504, '请求超时') SYSTEM_TIMEOUT_ERROR = HttpCode(504, '请求超时')
class ApiResponse(BaseModel, Generic[T]):
code: int = HttpResp.SUCCESS.code
message: str = HttpResp.SUCCESS.msg
data: Optional[T] = None

@ -1,8 +1,9 @@
from typing import Optional, List, Any from typing import Optional, List, Any, Generic, TypeVar
from typing import Union from typing import Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
T = TypeVar('T')
class BaseTabelDto(BaseModel): class BaseTabelDto(BaseModel):
id: Optional[str] = None id: Optional[str] = None
@ -27,14 +28,14 @@ class BaseRenameReq(BaseModel):
name: str name: str
class BasePageResp(BaseModel): class BasePageResp(BaseModel, Generic[T]):
page_number: Optional[int] page_number: Optional[int]
page_size: Optional[int] page_size: Optional[int]
page_count: Optional[int] page_count: Optional[int]
sort: Optional[str] sort: Optional[str]
orderby: Optional[str] orderby: Optional[str]
count: Optional[int] count: Optional[int]
data: Optional[List[Any]] data: Optional[List[T]]
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True

@ -1,8 +1,8 @@
import inspect import inspect
import logging import logging
from functools import wraps
from typing import Union, Type, Callable, TypeVar
from datetime import datetime from datetime import datetime
from functools import wraps
from typing import Union, Type, Callable, TypeVar, get_type_hints
import pytz import pytz
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
@ -10,15 +10,24 @@ from pydantic import BaseModel
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from config import get_settings from config import get_settings
from entity.dto import HttpResp from entity.dto import HttpResp, ApiResponse
from exceptions.base import AppException from exceptions.base import AppException
from utils import get_uuid from utils import get_uuid
RT = TypeVar('RT') # 返回类型 RT = TypeVar('RT') # 返回类型
def unified_resp(func: Callable[..., RT]) -> Callable[..., RT]: def unified_resp(func: Callable[..., RT]) -> Callable[..., RT]:
"""统一响应格式 """统一响应格式
接口正常返回时,统一响应结果格式 接口正常返回时,统一响应结果格式
""" """
# 获取原始函数的返回类型注解
hints = get_type_hints(func)
return_type = hints.get('return', None)
# 修改函数的返回类型注解
if return_type:
func.__annotations__['return'] = ApiResponse[return_type]
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs) -> RT: async def wrapper(*args, **kwargs) -> RT:

@ -1,5 +1,9 @@
from typing import List
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
from entity.db_models import User
from entity.dto.base import BasePageResp
from entity.dto.user_dto import UserQueryPageReq, UserQueryReq from entity.dto.user_dto import UserQueryPageReq, UserQueryReq
from router import BaseController, unified_resp from router import BaseController, unified_resp
from service.user_service import UserService from service.user_service import UserService
@ -10,10 +14,10 @@ base_app = BaseController(base_service)
@router.get("/page") @router.get("/page")
@unified_resp @unified_resp
async def get_page(req:UserQueryPageReq=Query(...)): async def get_page(req:UserQueryPageReq=Query(...)) -> BasePageResp[User]:
return await base_service.get_by_page(req) return await base_service.get_by_page(req)
@router.get("/list") @router.get("/list")
@unified_resp @unified_resp
async def get_list(req:UserQueryReq=Query(...)): async def get_list(req:UserQueryReq=Query(...))->List[User]:
return await base_service.get_list(req) return await base_service.get_list(req)

@ -1,11 +1,10 @@
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional from typing import Union, Type, List, Any, TypeVar, Generic
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from core.global_context import current_session from core.global_context import current_session
from entity import DbBaseModel from entity import DbBaseModel
@ -22,6 +21,7 @@ session.scalar: 直接明确获取一条数据,可以直接返回,无需额
""" """
T = TypeVar('T', bound=DbBaseModel) T = TypeVar('T', bound=DbBaseModel)
class BaseService(Generic[T]): class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型 model: Type[T] # 子类必须指定模型
@ -66,7 +66,7 @@ class BaseService(Generic[T]):
pass pass
@classmethod @classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp: async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
@ -75,7 +75,7 @@ class BaseService(Generic[T]):
@classmethod @classmethod
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None)->BasePageResp: dto_model_class: Type[BaseModel] = None) -> BasePageResp[T]:
if not query_params: if not query_params:
query_params = {} query_params = {}
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
@ -116,7 +116,7 @@ class BaseService(Generic[T]):
}) })
@classmethod @classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq])->List[T]: async def get_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[T]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
@ -150,7 +150,7 @@ class BaseService(Generic[T]):
return [item["id"] for item in exec_result.scalars().all()] return [item["id"] for item in exec_result.scalars().all()]
@classmethod @classmethod
async def save(cls, **kwargs)->T: async def save(cls, **kwargs) -> T:
sample_obj = cls.model(**kwargs) sample_obj = cls.model(**kwargs)
session = cls.get_db() session = cls.get_db()
session.add(sample_obj) session.add(sample_obj)
@ -158,7 +158,7 @@ class BaseService(Generic[T]):
return sample_obj return sample_obj
@classmethod @classmethod
async def insert_many(cls, data_list, batch_size=100)->None: async def insert_many(cls, data_list, batch_size=100) -> None:
async with cls.get_db() as session: async with cls.get_db() as session:
for d in data_list: for d in data_list:
if not d.get("id", None): if not d.get("id", None):
@ -168,27 +168,27 @@ class BaseService(Generic[T]):
session.add_all(data_list[i: i + batch_size]) session.add_all(data_list[i: i + batch_size])
@classmethod @classmethod
async def update_by_id(cls, pid, data)-> int: async def update_by_id(cls, pid, data) -> int:
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
session = cls.get_db() session = cls.get_db()
result = await session.execute(update_stmt) result = await session.execute(update_stmt)
return result.rowcount return result.rowcount
@classmethod @classmethod
async def update_many_by_id(cls, data_list)->None: async def update_many_by_id(cls, data_list) -> None:
async with cls.get_db() as session: async with cls.get_db() as session:
for data in data_list: for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt) await session.execute(stmt)
@classmethod @classmethod
async def get_by_id(cls, pid)->T: async def get_by_id(cls, pid) -> T:
stmt = cls.model.select().where(cls.model.id == pid) stmt = cls.model.select().where(cls.model.id == pid)
session = cls.get_db() session = cls.get_db()
return await session.scalar(stmt) return await session.scalar(stmt)
@classmethod @classmethod
async def get_by_ids(cls, pids, cols=None)->List[T]: async def get_by_ids(cls, pids, cols=None) -> List[T]:
if cols: if cols:
objs = cls.model.select(*cols) objs = cls.model.select(*cols)
else: else:
@ -199,14 +199,14 @@ class BaseService(Generic[T]):
return list(result.all()) return list(result.all())
@classmethod @classmethod
async def delete_by_id(cls, pid)-> int: async def delete_by_id(cls, pid) -> int:
del_stmt = cls.model.delete().where(cls.model.id == pid) del_stmt = cls.model.delete().where(cls.model.id == pid)
session = cls.get_db() session = cls.get_db()
exec_result = await session.execute(del_stmt) exec_result = await session.execute(del_stmt)
return exec_result.rowcount return exec_result.rowcount
@classmethod @classmethod
async def delete_by_ids(cls, pids)-> int: async def delete_by_ids(cls, pids) -> int:
session = cls.get_db() session = cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt) result = await session.execute(del_stmt)

Loading…
Cancel
Save