init: 完成 base_service.py,为后续提升开发速度

main
chenzhirong 5 months ago
parent c8c7bbfea4
commit 61eea6dc63

@ -25,28 +25,15 @@ class DbBaseModel(SQLModel, table=False):
fields = cls fields = cls
return Select(fields) return Select(fields)
@classmethod @classmethod
def delete(cls): def delete(cls):
return Delete(cls) return Delete(cls)
@classmethod
def delete_by_id(cls, id: str):
return Delete(cls).where(cls.id==id)
@classmethod
def delete_by_ids(cls, ids: list[str] | str):
if isinstance(ids, str):
ids = [ids]
return Delete(cls).where(cls.id.in_(ids))
@classmethod @classmethod
def update(cls): def update(cls):
return Update(cls) return Update(cls)
@classmethod
def update_by_id(cls, id: str,update_dict: dict):
update_dict.pop("id",None)
return Update(cls).where(cls.id == id).values(**update_dict)
@classmethod @classmethod
def update_by_ids(cls, ids: list[str],update_dict: dict): def update_by_ids(cls, ids: list[str],update_dict: dict):

@ -1,19 +1,25 @@
from typing import Union, Type, List, Any, TypeVar, Generic from typing import Union, Type, List, Any
from fastapi_pagination import Params, Page from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import Select, select, func from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from core.global_context import current_session from core.global_context import current_session
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
from utils import get_uuid, current_timestamp from utils import get_uuid, current_timestamp
"""
session.execute 执行任意数据库操作语句返回结果需要额外处理获取 数据形式Row对象类似元组
-- exec_result.scalars().all(): 从session.execute的执行结果中获取全部的数据
session.scalars: 只适合单模型查询不适合指定列或连表查询返回结果需要处理 数据形式直接返回模型实例如User对象
-- 处理exec_result.all(): 从session.scalars的执行结果中获取全部的数据
session.scalar: 直接明确获取一条数据可以直接返回无需额外处理
"""
class BaseService: class BaseService:
model=None # 子类必须指定模型 model = None # 子类必须指定模型
@classmethod @classmethod
def get_db(cls) -> AsyncSession: def get_db(cls) -> AsyncSession:
@ -23,13 +29,24 @@ class BaseService:
raise RuntimeError("No database session in context. " raise RuntimeError("No database session in context. "
"Make sure to use this service within a request context.") "Make sure to use this service within a request context.")
return session return session
@classmethod @classmethod
def entity_conversion_dto(cls, entity_data: Union[list, model], dto: Type[BaseModel]) -> Union[ def get_query_stmt(cls, query_params, sessions=None,*,fields:list=None):
if sessions is None:
if fields:
sessions = cls.model.select(*fields)
else:
sessions = cls.model.select()
for key, value in query_params.items():
if hasattr(cls.model, key) and value is not None:
field = getattr(cls.model, key)
sessions = sessions.where(field == value)
return sessions
@classmethod
def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[
BaseModel, List[BaseModel]]: BaseModel, List[BaseModel]]:
dto_list = [] dto_list = []
if not isinstance(entity_data, list): if not isinstance(entity_data, list):
return dto(**entity_data.to_dict()) return dto(**entity_data.model_dump())
for entity in entity_data: for entity in entity_data:
temp = entity temp = entity
if not isinstance(entity, dict): if not isinstance(entity, dict):
@ -37,22 +54,17 @@ class BaseService:
dto_list.append(dto(**temp)) dto_list.append(dto(**temp))
return dto_list return dto_list
@classmethod @classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]): async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
query_entity = cls.get_query_entity(query_params) query_stmt = cls.get_query_stmt(query_params)
return await cls.auto_page(query_entity, query_params) return await cls.auto_page(query_stmt, query_params)
# @classmethod
# def count_query(cls,query: Select) -> Select:
# # type: ignore
# return select(func.count("*")).select_from(count_subquery)
@classmethod @classmethod
async def auto_page(cls, query_entity, query_params: Union[dict, BasePageQueryReq] = None, async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None): dto_model_class: Type[BaseModel] = None):
if not query_params: if not query_params:
query_params = {} query_params = {}
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
@ -61,11 +73,8 @@ class BaseService:
page_size = query_params.get("page_size", 12) page_size = query_params.get("page_size", 12)
desc = query_params.get("desc", "desc") desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time") orderby = query_params.get("orderby", "created_time")
# data_count = await sessions.count()
session = cls.get_db() session = cls.get_db()
data_count = None
# data_count = session.scalar(cls.count_query(query_entity))
data_count =None
if data_count == 0: if data_count == 0:
return BasePageResp(**{ return BasePageResp(**{
"page_number": page_number, "page_number": page_number,
@ -76,17 +85,12 @@ class BaseService:
"data": [], "data": [],
}) })
if desc == "desc": if desc == "desc":
query_stmt = query_stmt.order_by(getattr(cls.model, orderby).desc())
query_entity = query_entity.order_by(getattr(cls.model,orderby).desc())
else: else:
query_entity = query_entity.order_by(getattr(cls.model,orderby).asc()) query_stmt = query_stmt.order_by(getattr(cls.model, orderby).asc())
query_page_result=await paginate(session, query_page_result = await paginate(session,
query_entity, query_stmt,
Params(page=page_number,size=page_size)) Params(page=page_number, size=page_size))
# query_entity = query_entity.offset((page_number - 1) * page_size).limit(page_size)
# query_exec_result = await session.execute(query_entity)
# result = query_exec_result.scalars().all()
# return query_page_result
result = query_page_result.items result = query_page_result.items
if dto_model_class is not None: if dto_model_class is not None:
result = [dto_model_class(**item) for item in result] result = [dto_model_class(**item) for item in result]
@ -101,161 +105,112 @@ class BaseService:
}) })
@classmethod @classmethod
def get_list(cls, query_params: Union[dict, BaseQueryReq]): async def get_list(cls, query_params: Union[dict, BaseQueryReq]):
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
desc = query_params.get("desc", "desc") desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time") orderby = query_params.get("orderby", "created_time")
sessions = cls.get_query_entity(query_params) query_stmt = cls.get_query_stmt(query_params)
field = getattr(cls.model, orderby)
if desc == "desc": if desc == "desc":
sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) query_stmt = query_stmt.order_by(field.desc())
else: else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) query_stmt = query_stmt.order_by(field.asc())
session = cls.get_db()
return sessions.scalars().all() exec_result = await session.execute(query_stmt)
return exec_result.scalars().all()
@classmethod @classmethod
def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]: async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
desc = query_params.get("desc", "desc") desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time") orderby = query_params.get("orderby", "created_time")
sessions = cls.model.select(cls.model.id) query_stmt = cls.model.select(cls.model.id)
sessions = cls.get_query_entity(query_params, sessions) query_stmt = cls.get_query_stmt(query_params, query_stmt)
if desc == "desc": if desc == "desc":
sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc())
else: else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc())
return [item["id"] for item in sessions.scalars().all()] session = cls.get_db()
exec_result = await session.execute(query_stmt)
return [item["id"] for item in exec_result.scalars().all()]
@classmethod @classmethod
def save(cls, **kwargs): async def save(cls, **kwargs):
"""Save a new record to database. sample_obj = cls.model(**kwargs)
session = cls.get_db()
This method creates a new record in the database with the provided field values, session.add(sample_obj)
forcing an insert operation rather than an update. await session.flush()
return sample_obj
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The created record object.
"""
# todo
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj > 0
@classmethod @classmethod
def insert_many(cls, data_list, batch_size=100): async def insert_many(cls, data_list, batch_size=100):
"""Insert multiple records in batches. async with cls.get_db() as session:
This method efficiently inserts multiple records into the database using batch processing.
It automatically sets creation timestamps for all records.
Args:
data_list (list): List of dictionaries containing record data to insert.
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
"""
with DB.atomic():
for d in data_list: for d in data_list:
if not d.get("id", None): if not d.get("id", None):
d["id"] = get_uuid() d["id"] = get_uuid()
d["created_time"] = current_timestamp()
# d["create_date"] = datetime_format(datetime.now())
for i in range(0, len(data_list), batch_size): for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i: i + batch_size]).execute() session.add_all(data_list[i: i + batch_size])
@classmethod @classmethod
def update_by_id(cls, pid, data): async def update_by_id(cls, pid, data):
# Update a single record by ID update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
# Args: session = cls.get_db()
# pid: Record ID result = await session.execute(update_stmt)
# data: Updated field values return result.rowcount
# Returns:
# Number of records updated
data["updated_time"] = current_timestamp()
num = cls.model.update(data).where(cls.model.id == pid).execute()
return num
@classmethod @classmethod
def update_many_by_id(cls, data_list): async def update_many_by_id(cls, data_list):
"""Update multiple records by their IDs. async with cls.get_db() as session:
This method updates multiple records in the database, identified by their IDs.
It automatically updates the updated_time and update_date fields for each record.
Args:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
with DB.atomic():
for data in data_list: for data in data_list:
data["updated_time"] = current_timestamp() stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
# data["update_date"] = datetime_format(datetime.now()) await session.execute(stmt)
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod @classmethod
def get_by_id(cls, pid): async def get_by_id(cls, pid):
# Get a record by ID stmt = cls.model.select(cls.model.id == pid)
# Args: session = cls.get_db()
# pid: Record ID return await session.scalar(stmt)
# Returns:
# Tuple of (success, record)
try:
obj = cls.model.get_or_none(cls.model.id == pid)
if obj:
return True, obj
except Exception:
pass
return False, None
@classmethod @classmethod
def get_by_ids(cls, pids, cols=None): async def get_by_ids(cls, pids, cols=None):
# Get multiple records by their IDs
# Args:
# pids: List of record IDs
# cols: List of columns to select
# Returns:
# Query of matching records
if cols: if cols:
objs = cls.model.select(*cols) objs = cls.model.select(*cols)
else: else:
objs = cls.model.select() objs = cls.model.select()
return objs.where(cls.model.id.in_(pids)) stmt = objs.where(cls.model.id.in_(pids))
session = cls.get_db()
result = await session.scalars(stmt)
return result.all()
@classmethod @classmethod
def delete_by_id(cls, pid): async def delete_by_id(cls, pid):
... del_stmt = cls.model.delete().where(cls.model.id == pid)
@classmethod session = cls.get_db()
def delete_by_ids(cls, pids): exec_result = await session.execute(del_stmt)
... return exec_result.rowcount
@classmethod @classmethod
def get_query_entity(cls, query_params, sessions=None): async def delete_by_ids(cls, pids):
if sessions is None: session = cls.get_db()
sessions = cls.model.select() del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
for key, value in query_params.items(): result = await session.execute(del_stmt)
if hasattr(cls.model, key): return result.rowcount
field = getattr(cls.model, key)
sessions = sessions.where(field == value)
return sessions
@classmethod @classmethod
def get_data_count(cls, query_params: dict = None): async def get_data_count(cls, query_params: dict = None)->int:
if not query_params: if not query_params:
raise Exception("参数为空") raise Exception("参数为空")
sessions = cls.get_query_entity(query_params) stmt = cls.get_query_stmt(query_params,fields=[func.count(cls.model.id)])
return sessions.count() # stmt = cls.get_query_stmt(query_params)
session = cls.get_db()
return await session.scalar(stmt)
@classmethod
def is_exist(cls, query_params: dict = None):
return cls.get_data_count(query_params) > 0
@classmethod
async def is_exist(cls, query_params: dict = None):
return await cls.get_data_count(query_params) > 0

Loading…
Cancel
Save