perf: 优化unified_resp装饰器(统一接口响应),确保 api文档展示 的响应类型与真实的一致

main
chenzhirong 4 months ago
parent 9a211f9484
commit 521bf98c56

@ -1,10 +1,15 @@
from collections import namedtuple
from typing import TypeVar, Generic, Optional
from beartype.claw import beartype_this_package
from pydantic import BaseModel
beartype_this_package()
HttpCode = namedtuple('HttpResp', ['code', 'msg'])
T = TypeVar("T")
class HttpResp:
"""HTTP响应结果
"""
@ -24,4 +29,10 @@ class HttpResp:
REQUEST_404_ERROR = HttpCode(404, '请求接口不存在')
DATA_ALREADY_EXISTS = HttpCode(409, '数据已存在')
SYSTEM_ERROR = HttpCode(500, '系统错误')
SYSTEM_TIMEOUT_ERROR = HttpCode(504, '请求超时')
SYSTEM_TIMEOUT_ERROR = HttpCode(504, '请求超时')
class ApiResponse(BaseModel, Generic[T]):
code: int = HttpResp.SUCCESS.code
message: str = HttpResp.SUCCESS.msg
data: Optional[T] = None

@ -1,8 +1,9 @@
from typing import Optional, List, Any
from typing import Optional, List, Any, Generic, TypeVar
from typing import Union
from pydantic import BaseModel, Field
T = TypeVar('T')
class BaseTabelDto(BaseModel):
id: Optional[str] = None
@ -27,14 +28,14 @@ class BaseRenameReq(BaseModel):
name: str
class BasePageResp(BaseModel):
class BasePageResp(BaseModel, Generic[T]):
page_number: Optional[int]
page_size: Optional[int]
page_count: Optional[int]
sort: Optional[str]
orderby: Optional[str]
count: Optional[int]
data: Optional[List[Any]]
data: Optional[List[T]]
class Config:
arbitrary_types_allowed = True

@ -1,8 +1,8 @@
import inspect
import logging
from functools import wraps
from typing import Union, Type, Callable, TypeVar
from datetime import datetime
from functools import wraps
from typing import Union, Type, Callable, TypeVar, get_type_hints
import pytz
from fastapi.encoders import jsonable_encoder
@ -10,15 +10,24 @@ from pydantic import BaseModel
from starlette.responses import JSONResponse
from config import get_settings
from entity.dto import HttpResp
from entity.dto import HttpResp, ApiResponse
from exceptions.base import AppException
from utils import get_uuid
RT = TypeVar('RT') # 返回类型
def unified_resp(func: Callable[..., RT]) -> Callable[..., RT]:
"""统一响应格式
接口正常返回时,统一响应结果格式
"""
# 获取原始函数的返回类型注解
hints = get_type_hints(func)
return_type = hints.get('return', None)
# 修改函数的返回类型注解
if return_type:
func.__annotations__['return'] = ApiResponse[return_type]
@wraps(func)
async def wrapper(*args, **kwargs) -> RT:

@ -1,5 +1,9 @@
from typing import List
from fastapi import APIRouter, Query
from entity.db_models import User
from entity.dto.base import BasePageResp
from entity.dto.user_dto import UserQueryPageReq, UserQueryReq
from router import BaseController, unified_resp
from service.user_service import UserService
@ -10,10 +14,10 @@ base_app = BaseController(base_service)
@router.get("/page")
@unified_resp
async def get_page(req:UserQueryPageReq=Query(...)):
async def get_page(req:UserQueryPageReq=Query(...)) -> BasePageResp[User]:
return await base_service.get_by_page(req)
@router.get("/list")
@unified_resp
async def get_list(req:UserQueryReq=Query(...)):
async def get_list(req:UserQueryReq=Query(...))->List[User]:
return await base_service.get_list(req)

@ -1,11 +1,10 @@
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
from typing import Union, Type, List, Any, TypeVar, Generic
from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from core.global_context import current_session
from entity import DbBaseModel
@ -22,6 +21,7 @@ session.scalar: 直接明确获取一条数据,可以直接返回,无需额
"""
T = TypeVar('T', bound=DbBaseModel)
class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型
@ -66,7 +66,7 @@ class BaseService(Generic[T]):
pass
@classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp:
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
@ -75,7 +75,7 @@ class BaseService(Generic[T]):
@classmethod
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None)->BasePageResp:
dto_model_class: Type[BaseModel] = None) -> BasePageResp[T]:
if not query_params:
query_params = {}
if not isinstance(query_params, dict):
@ -116,7 +116,7 @@ class BaseService(Generic[T]):
})
@classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq])->List[T]:
async def get_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[T]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
@ -150,7 +150,7 @@ class BaseService(Generic[T]):
return [item["id"] for item in exec_result.scalars().all()]
@classmethod
async def save(cls, **kwargs)->T:
async def save(cls, **kwargs) -> T:
sample_obj = cls.model(**kwargs)
session = cls.get_db()
session.add(sample_obj)
@ -158,7 +158,7 @@ class BaseService(Generic[T]):
return sample_obj
@classmethod
async def insert_many(cls, data_list, batch_size=100)->None:
async def insert_many(cls, data_list, batch_size=100) -> None:
async with cls.get_db() as session:
for d in data_list:
if not d.get("id", None):
@ -168,27 +168,27 @@ class BaseService(Generic[T]):
session.add_all(data_list[i: i + batch_size])
@classmethod
async def update_by_id(cls, pid, data)-> int:
async def update_by_id(cls, pid, data) -> int:
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
session = cls.get_db()
result = await session.execute(update_stmt)
return result.rowcount
@classmethod
async def update_many_by_id(cls, data_list)->None:
async def update_many_by_id(cls, data_list) -> None:
async with cls.get_db() as session:
for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt)
@classmethod
async def get_by_id(cls, pid)->T:
async def get_by_id(cls, pid) -> T:
stmt = cls.model.select().where(cls.model.id == pid)
session = cls.get_db()
return await session.scalar(stmt)
@classmethod
async def get_by_ids(cls, pids, cols=None)->List[T]:
async def get_by_ids(cls, pids, cols=None) -> List[T]:
if cols:
objs = cls.model.select(*cols)
else:
@ -199,14 +199,14 @@ class BaseService(Generic[T]):
return list(result.all())
@classmethod
async def delete_by_id(cls, pid)-> int:
async def delete_by_id(cls, pid) -> int:
del_stmt = cls.model.delete().where(cls.model.id == pid)
session = cls.get_db()
exec_result = await session.execute(del_stmt)
return exec_result.rowcount
@classmethod
async def delete_by_ids(cls, pids)-> int:
async def delete_by_ids(cls, pids) -> int:
session = cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt)

Loading…
Cancel
Save