|
|
|
@ -29,7 +29,7 @@ class BaseService(Generic[T]):
|
|
|
|
def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
|
|
|
|
def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
|
|
|
|
if stmt is None:
|
|
|
|
if stmt is None:
|
|
|
|
if fields:
|
|
|
|
if fields:
|
|
|
|
stmt = cls.model.select(*fields)
|
|
|
|
stmt = cls.model.select(fields)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
stmt = cls.model.select()
|
|
|
|
stmt = cls.model.select()
|
|
|
|
for key, value in query_params.items():
|
|
|
|
for key, value in query_params.items():
|
|
|
|
@ -48,6 +48,9 @@ class BaseService(Generic[T]):
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[
|
|
|
|
def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[
|
|
|
|
BaseModel, List[BaseModel]]:
|
|
|
|
BaseModel, List[BaseModel]]:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
数据脱敏
|
|
|
|
|
|
|
|
"""
|
|
|
|
dto_list = []
|
|
|
|
dto_list = []
|
|
|
|
if not isinstance(entity_data, list):
|
|
|
|
if not isinstance(entity_data, list):
|
|
|
|
return dto(**entity_data.model_dump())
|
|
|
|
return dto(**entity_data.model_dump())
|
|
|
|
@ -59,23 +62,27 @@ class BaseService(Generic[T]):
|
|
|
|
return dto_list
|
|
|
|
return dto_list
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def check_base_permission(cls, daba: Any):
|
|
|
|
async def check_base_permission(cls, daba: Any):
|
|
|
|
# todo
|
|
|
|
# todo
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]:
|
|
|
|
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq], dto_model_class: Type[BaseModel] = None) -> \
|
|
|
|
|
|
|
|
BasePageResp[T]:
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
|
|
|
|
query_stmt = cls.get_query_stmt(query_params)
|
|
|
|
fields = None
|
|
|
|
return await cls.auto_page(query_stmt, query_params)
|
|
|
|
if dto_model_class is not None:
|
|
|
|
|
|
|
|
fields = [getattr(cls.model, key) for key in dto_model_class.model_fields.keys()]
|
|
|
|
|
|
|
|
query_stmt = cls.get_query_stmt(query_params, fields=fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return await cls.auto_page(query_stmt, query_params, dto_model_class)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@with_db_session()
|
|
|
|
@with_db_session()
|
|
|
|
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> \
|
|
|
|
dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> BasePageResp[T]:
|
|
|
|
BasePageResp[T]:
|
|
|
|
|
|
|
|
if not query_params:
|
|
|
|
if not query_params:
|
|
|
|
query_params = {}
|
|
|
|
query_params = {}
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
@ -84,87 +91,45 @@ class BaseService(Generic[T]):
|
|
|
|
page_size = query_params.get("page_size", 12)
|
|
|
|
page_size = query_params.get("page_size", 12)
|
|
|
|
sort = query_params.get("sort", "desc")
|
|
|
|
sort = query_params.get("sort", "desc")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
data_count = None
|
|
|
|
|
|
|
|
if data_count == 0:
|
|
|
|
|
|
|
|
return BasePageResp(**{
|
|
|
|
|
|
|
|
"page_number": page_number,
|
|
|
|
|
|
|
|
"page_size": page_size,
|
|
|
|
|
|
|
|
"count": data_count,
|
|
|
|
|
|
|
|
"sort": sort,
|
|
|
|
|
|
|
|
"orderby": orderby,
|
|
|
|
|
|
|
|
"data": [],
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
if sort == "desc":
|
|
|
|
if sort == "desc":
|
|
|
|
query_stmt = query_stmt.order_by(getattr(cls.model, orderby).desc())
|
|
|
|
query_stmt = query_stmt.order_by(getattr(cls.model, orderby).desc())
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
query_stmt = query_stmt.order_by(getattr(cls.model, orderby).asc())
|
|
|
|
query_stmt = query_stmt.order_by(getattr(cls.model, orderby).asc())
|
|
|
|
query_page_result = await paginate(session,
|
|
|
|
|
|
|
|
query_stmt,
|
|
|
|
query_page_result = await paginate(session, query_stmt, Params(page=page_number, size=page_size))
|
|
|
|
Params(page=page_number, size=page_size))
|
|
|
|
|
|
|
|
result = query_page_result.items
|
|
|
|
result = query_page_result.items
|
|
|
|
if dto_model_class is not None:
|
|
|
|
if dto_model_class is not None:
|
|
|
|
result = [dto_model_class(**item) for item in result]
|
|
|
|
# 使用 row._mapping 将 Row 对象转换为可解包的字典
|
|
|
|
return BasePageResp(**{
|
|
|
|
result = [dto_model_class(**dict(row._mapping)) for row in result]
|
|
|
|
"page_number": page_number,
|
|
|
|
return BasePageResp(
|
|
|
|
"page_size": page_size,
|
|
|
|
**{"page_number": page_number, "page_size": page_size, "page_count": query_page_result.pages,
|
|
|
|
"page_count": query_page_result.pages,
|
|
|
|
"count": query_page_result.total, "sort": sort, "orderby": orderby, "data": result, })
|
|
|
|
"count": query_page_result.total,
|
|
|
|
|
|
|
|
"sort": sort,
|
|
|
|
|
|
|
|
"orderby": orderby,
|
|
|
|
|
|
|
|
"data": result,
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@with_db_session()
|
|
|
|
@with_db_session()
|
|
|
|
async def get_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> List[
|
|
|
|
async def get_list(cls, query_params: Union[dict, BaseQueryReq], dto_model_class: Type[BaseModel] = None, *,
|
|
|
|
T]:
|
|
|
|
session: Optional[AsyncSession] = None) -> List[T] | List[BaseModel]:
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
"""
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params: 参数字典 or 参数请求模型
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
|
|
|
|
dto_model_class: 输出类型
|
|
|
|
sort = query_params.get("sort", "desc")
|
|
|
|
session: 数据库会话---支持传递以便于事务管控
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
获取数据集合
|
|
|
|
query_stmt = cls.get_query_stmt(query_params)
|
|
|
|
"""
|
|
|
|
field = getattr(cls.model, orderby)
|
|
|
|
query_stmt = cls.build_query(query_params, dto_model_class=dto_model_class)
|
|
|
|
if sort == "desc":
|
|
|
|
|
|
|
|
query_stmt = query_stmt.order_by(field.desc())
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
query_stmt = query_stmt.order_by(field.asc())
|
|
|
|
|
|
|
|
if query_params.get("limit", None) is not None:
|
|
|
|
|
|
|
|
query_stmt = query_stmt.limit(query_params.get("limit"))
|
|
|
|
|
|
|
|
exec_result = await session.execute(query_stmt)
|
|
|
|
exec_result = await session.execute(query_stmt)
|
|
|
|
return list(exec_result.scalars().all())
|
|
|
|
return cls.parse_result(exec_result, dto_model_class=dto_model_class)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@with_db_session()
|
|
|
|
@with_db_session()
|
|
|
|
async def get_list_json(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
|
|
|
|
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None):
|
|
|
|
List[
|
|
|
|
query_stmt = cls.build_query(query_params, fields=[cls.model.id])
|
|
|
|
T]:
|
|
|
|
exec_result = await session.scalars(query_stmt)
|
|
|
|
resp_list = await cls.get_list(query_params, session=session)
|
|
|
|
return list(exec_result)
|
|
|
|
|
|
|
|
|
|
|
|
return [i.model_dump() for i in resp_list]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
|
|
|
|
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
|
|
|
|
|
|
|
|
List[str]:
|
|
|
|
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
|
|
|
|
sort = query_params.get("sort", "desc")
|
|
|
|
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
|
|
|
|
query_stmt = cls.model.select(cls.model.id)
|
|
|
|
|
|
|
|
query_stmt = cls.get_query_stmt(query_params, query_stmt)
|
|
|
|
|
|
|
|
if sort == "desc":
|
|
|
|
|
|
|
|
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc())
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc())
|
|
|
|
|
|
|
|
exec_result = await session.execute(query_stmt)
|
|
|
|
|
|
|
|
return [item["id"] for item in exec_result.scalars().all()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@with_db_session()
|
|
|
|
@with_db_session()
|
|
|
|
async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T:
|
|
|
|
async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T:
|
|
|
|
|
|
|
|
|
|
|
|
sample_obj = cls.model(**kwargs)
|
|
|
|
sample_obj = cls.model(**kwargs)
|
|
|
|
session.add(sample_obj)
|
|
|
|
session.add(sample_obj)
|
|
|
|
await session.flush()
|
|
|
|
await session.flush()
|
|
|
|
@ -206,7 +171,6 @@ class BaseService(Generic[T]):
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@with_db_session()
|
|
|
|
@with_db_session()
|
|
|
|
async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T:
|
|
|
|
async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T:
|
|
|
|
|
|
|
|
|
|
|
|
stmt = cls.model.select().where(cls.model.id == pid)
|
|
|
|
stmt = cls.model.select().where(cls.model.id == pid)
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
@ -222,15 +186,12 @@ class BaseService(Generic[T]):
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@with_db_session()
|
|
|
|
@with_db_session()
|
|
|
|
async def get_by_ids(cls, pids, cols=None, *, session: Optional[AsyncSession] = None) -> List[T]:
|
|
|
|
async def get_by_ids(cls, pids, dto_model_class: Type[BaseModel] = None, *,
|
|
|
|
|
|
|
|
session: Optional[AsyncSession] = None) -> List[T]:
|
|
|
|
if cols:
|
|
|
|
stmt = cls.build_query({}, dto_model_class=dto_model_class)
|
|
|
|
objs = cls.model.select(*cols)
|
|
|
|
stmt = stmt.where(cls.model.id.in_(pids))
|
|
|
|
else:
|
|
|
|
exec_result = await session.execute(stmt)
|
|
|
|
objs = cls.model.select()
|
|
|
|
return cls.parse_result(exec_result, dto_model_class=dto_model_class)
|
|
|
|
stmt = objs.where(cls.model.id.in_(pids))
|
|
|
|
|
|
|
|
result = await session.scalars(stmt)
|
|
|
|
|
|
|
|
return list(result.all())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@with_db_session()
|
|
|
|
@with_db_session()
|
|
|
|
@ -271,3 +232,52 @@ class BaseService(Generic[T]):
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def is_exist(cls, query_params: dict = None):
|
|
|
|
async def is_exist(cls, query_params: dict = None):
|
|
|
|
return await cls.get_data_count(query_params) > 0
|
|
|
|
return await cls.get_data_count(query_params) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def build_query(cls, query_params: Union[dict, BaseQueryReq], *, dto_model_class: Type[BaseModel] = None,
|
|
|
|
|
|
|
|
fields: List = None):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
可选dto_model_class、fields
|
|
|
|
|
|
|
|
优先级dto_model_class>fields
|
|
|
|
|
|
|
|
如果传递了dto_model_class则无需传递fields,fields会被dto_model_class覆盖
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
|
|
|
|
|
|
|
|
sort = query_params.get("sort", "desc").lower()
|
|
|
|
|
|
|
|
orderby = query_params.get("orderby", "created_time").lower()
|
|
|
|
|
|
|
|
if dto_model_class is not None:
|
|
|
|
|
|
|
|
# 安全映射:只获取 DTO 中存在且 Model 中也有的字段
|
|
|
|
|
|
|
|
# 避免 DTO 包含计算字段时 getattr 报错
|
|
|
|
|
|
|
|
db_columns = []
|
|
|
|
|
|
|
|
for key in dto_model_class.model_fields.keys():
|
|
|
|
|
|
|
|
if hasattr(cls.model, key):
|
|
|
|
|
|
|
|
db_columns.append(getattr(cls.model, key))
|
|
|
|
|
|
|
|
fields = db_columns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
query_stmt = cls.get_query_stmt(query_params, fields=fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if hasattr(cls.model, orderby):
|
|
|
|
|
|
|
|
order_field = getattr(cls.model, orderby)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
# 如果传入的 orderby 字段不存在,回退到默认排序,防止报错
|
|
|
|
|
|
|
|
order_field = cls.model.created_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 根据xxx字段排序
|
|
|
|
|
|
|
|
if sort == "desc":
|
|
|
|
|
|
|
|
query_stmt = query_stmt.order_by(order_field.desc())
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
query_stmt = query_stmt.order_by(order_field.asc())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if query_params.get("limit", None) is not None:
|
|
|
|
|
|
|
|
query_stmt = query_stmt.limit(query_params.get("limit"))
|
|
|
|
|
|
|
|
return query_stmt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def parse_result(cls, exec_result, dto_model_class: Type[BaseModel] = None):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
将数据库执行结果解析为 DTO 列表或 实体列表
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if dto_model_class is not None:
|
|
|
|
|
|
|
|
return [dto_model_class(**dict(row._mapping)) for row in exec_result.all()]
|
|
|
|
|
|
|
|
return list(exec_result.scalars().all())
|
|
|
|
|