|
|
|
|
@ -1,16 +1,19 @@
|
|
|
|
|
from typing import Union, Type, List, Any, TypeVar, Generic
|
|
|
|
|
|
|
|
|
|
from fastapi_pagination import Params, Page
|
|
|
|
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from sqlalchemy import Select, select, func
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
from core.global_context import current_session
|
|
|
|
|
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
|
|
|
|
|
from utils import get_uuid, current_timestamp
|
|
|
|
|
|
|
|
|
|
T = TypeVar('T')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseService(Generic[T]):
|
|
|
|
|
model: Type[T] # 子类必须指定模型
|
|
|
|
|
class BaseService:
|
|
|
|
|
model=None # 子类必须指定模型
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_db(cls) -> AsyncSession:
|
|
|
|
|
@ -21,15 +24,6 @@ class BaseService(Generic[T]):
|
|
|
|
|
"Make sure to use this service within a request context.")
|
|
|
|
|
return session
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
async def create(cls, **kwargs) -> T:
|
|
|
|
|
"""通用创建方法"""
|
|
|
|
|
obj = cls.model(**kwargs)
|
|
|
|
|
db = cls.get_db()
|
|
|
|
|
db.add(obj)
|
|
|
|
|
await db.flush()
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def entity_conversion_dto(cls, entity_data: Union[list, model], dto: Type[BaseModel]) -> Union[
|
|
|
|
|
BaseModel, List[BaseModel]]:
|
|
|
|
|
@ -43,100 +37,21 @@ class BaseService(Generic[T]):
|
|
|
|
|
dto_list.append(dto(**temp))
|
|
|
|
|
return dto_list
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
|
|
|
|
"""Execute a database query with optional column selection and ordering.
|
|
|
|
|
|
|
|
|
|
This method provides a flexible way to query the database with various filters
|
|
|
|
|
and sorting options. It supports column selection, sort order control, and
|
|
|
|
|
additional filter conditions.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cols (list, optional): List of column names to select. If None, selects all columns.
|
|
|
|
|
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
|
|
|
|
|
order_by (str, optional): Column name to sort results by.
|
|
|
|
|
**kwargs: Additional filter conditions passed as keyword arguments.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
peewee.ModelSelect: A query result containing matching records.
|
|
|
|
|
"""
|
|
|
|
|
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_all(cls, cols=None, reverse=None, order_by=None):
|
|
|
|
|
"""Retrieve all records from the database with optional column selection and ordering.
|
|
|
|
|
|
|
|
|
|
This method fetches all records from the model's table with support for
|
|
|
|
|
column selection and result ordering. If no order_by is specified and reverse
|
|
|
|
|
is True, it defaults to ordering by created_time.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cols (list, optional): List of column names to select. If None, selects all columns.
|
|
|
|
|
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
|
|
|
|
|
order_by (str, optional): Column name to sort results by. Defaults to 'created_time' if reverse is specified.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
peewee.ModelSelect: A query containing all matching records.
|
|
|
|
|
"""
|
|
|
|
|
if cols:
|
|
|
|
|
query_records = cls.model.select(*cols)
|
|
|
|
|
else:
|
|
|
|
|
query_records = cls.model.select()
|
|
|
|
|
if reverse is not None:
|
|
|
|
|
if not order_by or not hasattr(cls, order_by):
|
|
|
|
|
order_by = "created_time"
|
|
|
|
|
if reverse is True:
|
|
|
|
|
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
|
|
|
|
|
elif reverse is False:
|
|
|
|
|
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
|
|
|
|
|
return query_records
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get(cls, **kwargs):
|
|
|
|
|
"""Get a single record matching the given criteria.
|
|
|
|
|
|
|
|
|
|
This method retrieves a single record from the database that matches
|
|
|
|
|
the specified filter conditions.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
**kwargs: Filter conditions as keyword arguments.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Model instance: Single matching record.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
peewee.DoesNotExist: If no matching record is found.
|
|
|
|
|
"""
|
|
|
|
|
return cls.model.get(**kwargs)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_or_none(cls, **kwargs):
|
|
|
|
|
"""Get a single record or None if not found.
|
|
|
|
|
|
|
|
|
|
This method attempts to retrieve a single record matching the given criteria,
|
|
|
|
|
returning None if no match is found instead of raising an exception.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
**kwargs: Filter conditions as keyword arguments.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Model instance or None: Matching record if found, None otherwise.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
return cls.model.get(**kwargs)
|
|
|
|
|
except peewee.DoesNotExist:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
|
|
|
|
|
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
|
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
|
sessions = cls.get_query_session(query_params)
|
|
|
|
|
return cls.auto_page(sessions, query_params)
|
|
|
|
|
query_entity = cls.get_query_entity(query_params)
|
|
|
|
|
return await cls.auto_page(query_entity, query_params)
|
|
|
|
|
# @classmethod
|
|
|
|
|
# def count_query(cls,query: Select) -> Select:
|
|
|
|
|
# # type: ignore
|
|
|
|
|
# return select(func.count("*")).select_from(count_subquery)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def auto_page(cls, sessions, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
|
async def auto_page(cls, query_entity, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
|
dto_model_class: Type[BaseModel] = None):
|
|
|
|
|
if not query_params:
|
|
|
|
|
query_params = {}
|
|
|
|
|
@ -146,7 +61,11 @@ class BaseService(Generic[T]):
|
|
|
|
|
page_size = query_params.get("page_size", 12)
|
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
|
data_count = sessions.count()
|
|
|
|
|
# data_count = await sessions.count()
|
|
|
|
|
session = cls.get_db()
|
|
|
|
|
|
|
|
|
|
# data_count = session.scalar(cls.count_query(query_entity))
|
|
|
|
|
data_count =None
|
|
|
|
|
if data_count == 0:
|
|
|
|
|
return BasePageResp(**{
|
|
|
|
|
"page_number": page_number,
|
|
|
|
|
@ -157,18 +76,25 @@ class BaseService(Generic[T]):
|
|
|
|
|
"data": [],
|
|
|
|
|
})
|
|
|
|
|
if desc == "desc":
|
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
|
|
|
|
|
|
|
|
|
query_entity = query_entity.order_by(getattr(cls.model,orderby).desc())
|
|
|
|
|
else:
|
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
|
|
|
|
sessions = sessions.paginate(int(page_number), int(page_size))
|
|
|
|
|
datas = list(sessions.dicts())
|
|
|
|
|
result = datas
|
|
|
|
|
query_entity = query_entity.order_by(getattr(cls.model,orderby).asc())
|
|
|
|
|
query_page_result=await paginate(session,
|
|
|
|
|
query_entity,
|
|
|
|
|
Params(page=page_number,size=page_size))
|
|
|
|
|
# query_entity = query_entity.offset((page_number - 1) * page_size).limit(page_size)
|
|
|
|
|
# query_exec_result = await session.execute(query_entity)
|
|
|
|
|
# result = query_exec_result.scalars().all()
|
|
|
|
|
# return query_page_result
|
|
|
|
|
result = query_page_result.items
|
|
|
|
|
if dto_model_class is not None:
|
|
|
|
|
result = [dto_model_class(**item) for item in datas]
|
|
|
|
|
result = [dto_model_class(**item) for item in result]
|
|
|
|
|
return BasePageResp(**{
|
|
|
|
|
"page_number": page_number,
|
|
|
|
|
"page_size": page_size,
|
|
|
|
|
"count": data_count,
|
|
|
|
|
"page_count": query_page_result.pages,
|
|
|
|
|
"count": query_page_result.total,
|
|
|
|
|
"desc": desc,
|
|
|
|
|
"orderby": orderby,
|
|
|
|
|
"data": result,
|
|
|
|
|
@ -181,13 +107,13 @@ class BaseService(Generic[T]):
|
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
|
sessions = cls.get_query_session(query_params)
|
|
|
|
|
sessions = cls.get_query_entity(query_params)
|
|
|
|
|
if desc == "desc":
|
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
|
|
|
|
else:
|
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
|
|
|
|
|
|
|
|
|
return list(sessions.dicts())
|
|
|
|
|
return sessions.scalars().all()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
|
|
|
|
|
@ -197,12 +123,12 @@ class BaseService(Generic[T]):
|
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
|
sessions = cls.model.select(cls.model.id)
|
|
|
|
|
sessions = cls.get_query_session(query_params, sessions)
|
|
|
|
|
sessions = cls.get_query_entity(query_params, sessions)
|
|
|
|
|
if desc == "desc":
|
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
|
|
|
|
else:
|
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
|
|
|
|
return [item["id"] for item in list(sessions.dicts())]
|
|
|
|
|
return [item["id"] for item in sessions.scalars().all()]
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def save(cls, **kwargs):
|
|
|
|
|
@ -217,32 +143,10 @@ class BaseService(Generic[T]):
|
|
|
|
|
Returns:
|
|
|
|
|
Model instance: The created record object.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# todo
|
|
|
|
|
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
|
|
|
|
return sample_obj > 0
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def insert(cls, **kwargs):
|
|
|
|
|
"""Insert a new record with automatic ID and timestamps.
|
|
|
|
|
|
|
|
|
|
This method creates a new record with automatically generated ID and timestamp fields.
|
|
|
|
|
It handles the creation of created_time, create_date, updated_time, and update_date fields.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
**kwargs: Record field values as keyword arguments.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Model instance: The newly created record object.
|
|
|
|
|
"""
|
|
|
|
|
if "id" not in kwargs:
|
|
|
|
|
kwargs["id"] = get_uuid()
|
|
|
|
|
|
|
|
|
|
kwargs["created_time"] = current_timestamp()
|
|
|
|
|
# kwargs["create_date"] = datetime_format(datetime.now())
|
|
|
|
|
kwargs["updated_time"] = current_timestamp()
|
|
|
|
|
# kwargs["update_date"] = datetime_format(datetime.now())
|
|
|
|
|
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
|
|
|
|
return sample_obj > 0
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def insert_many(cls, data_list, batch_size=100):
|
|
|
|
|
@ -266,6 +170,18 @@ class BaseService(Generic[T]):
|
|
|
|
|
for i in range(0, len(data_list), batch_size):
|
|
|
|
|
cls.model.insert_many(data_list[i: i + batch_size]).execute()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def update_by_id(cls, pid, data):
|
|
|
|
|
# Update a single record by ID
|
|
|
|
|
# Args:
|
|
|
|
|
# pid: Record ID
|
|
|
|
|
# data: Updated field values
|
|
|
|
|
# Returns:
|
|
|
|
|
# Number of records updated
|
|
|
|
|
data["updated_time"] = current_timestamp()
|
|
|
|
|
num = cls.model.update(data).where(cls.model.id == pid).execute()
|
|
|
|
|
return num
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def update_many_by_id(cls, data_list):
|
|
|
|
|
"""Update multiple records by their IDs.
|
|
|
|
|
@ -284,18 +200,6 @@ class BaseService(Generic[T]):
|
|
|
|
|
# data["update_date"] = datetime_format(datetime.now())
|
|
|
|
|
cls.model.update(data).where(cls.model.id == data["id"]).execute()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def updated_by_id(cls, pid, data):
|
|
|
|
|
# Update a single record by ID
|
|
|
|
|
# Args:
|
|
|
|
|
# pid: Record ID
|
|
|
|
|
# data: Updated field values
|
|
|
|
|
# Returns:
|
|
|
|
|
# Number of records updated
|
|
|
|
|
data["updated_time"] = current_timestamp()
|
|
|
|
|
# data["update_date"] = datetime_format(datetime.now())
|
|
|
|
|
num = cls.model.update(data).where(cls.model.id == pid).execute()
|
|
|
|
|
return num > 0
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_by_id(cls, pid):
|
|
|
|
|
@ -326,101 +230,16 @@ class BaseService(Generic[T]):
|
|
|
|
|
objs = cls.model.select()
|
|
|
|
|
return objs.where(cls.model.id.in_(pids))
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_last_by_create_time(cls):
|
|
|
|
|
# Get multiple records by their IDs
|
|
|
|
|
# Args:
|
|
|
|
|
# pids: List of record IDs
|
|
|
|
|
# cols: List of columns to select
|
|
|
|
|
# Returns:
|
|
|
|
|
# Query of matching records
|
|
|
|
|
latest = cls.model.select().order_by(cls.model.created_time.desc()).first()
|
|
|
|
|
return latest
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def delete_by_id(cls, pid):
|
|
|
|
|
# Delete a record by ID
|
|
|
|
|
# Args:
|
|
|
|
|
# pid: Record ID
|
|
|
|
|
# Returns:
|
|
|
|
|
# Number of records deleted
|
|
|
|
|
return cls.model.delete().where(cls.model.id == pid).execute()
|
|
|
|
|
|
|
|
|
|
...
|
|
|
|
|
@classmethod
|
|
|
|
|
def delete_by_ids(cls, pids):
|
|
|
|
|
# Delete multiple records by their IDs
|
|
|
|
|
# Args:
|
|
|
|
|
# pids: List of record IDs
|
|
|
|
|
# Returns:
|
|
|
|
|
# Number of records deleted
|
|
|
|
|
with DB.atomic():
|
|
|
|
|
res = cls.model.delete().where(cls.model.id.in_(pids)).execute()
|
|
|
|
|
return res
|
|
|
|
|
...
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def filter_delete(cls, filters):
|
|
|
|
|
# Delete records matching given filters
|
|
|
|
|
# Args:
|
|
|
|
|
# filters: List of filter conditions
|
|
|
|
|
# Returns:
|
|
|
|
|
# Number of records deleted
|
|
|
|
|
with DB.atomic():
|
|
|
|
|
num = cls.model.delete().where(*filters).execute()
|
|
|
|
|
return num
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def filter_update(cls, filters, update_data):
|
|
|
|
|
# Update records matching given filters
|
|
|
|
|
# Args:
|
|
|
|
|
# filters: List of filter conditions
|
|
|
|
|
# update_data: Updated field values
|
|
|
|
|
# Returns:
|
|
|
|
|
# Number of records updated
|
|
|
|
|
with DB.atomic():
|
|
|
|
|
return cls.model.update(update_data).where(*filters).execute()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def cut_list(tar_list, n):
|
|
|
|
|
# Split a list into chunks of size n
|
|
|
|
|
# Args:
|
|
|
|
|
# tar_list: List to split
|
|
|
|
|
# n: Chunk size
|
|
|
|
|
# Returns:
|
|
|
|
|
# List of tuples containing chunks
|
|
|
|
|
length = len(tar_list)
|
|
|
|
|
arr = range(length)
|
|
|
|
|
result = [tuple(tar_list[x: (x + n)]) for x in arr[::n]]
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
|
|
|
|
|
# Get records matching IN clause filters with optional column selection
|
|
|
|
|
# Args:
|
|
|
|
|
# in_key: Field name for IN clause
|
|
|
|
|
# in_filters_list: List of values for IN clause
|
|
|
|
|
# filters: Additional filter conditions
|
|
|
|
|
# cols: List of columns to select
|
|
|
|
|
# Returns:
|
|
|
|
|
# List of matching records
|
|
|
|
|
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
|
|
|
|
|
if not filters:
|
|
|
|
|
filters = []
|
|
|
|
|
res_list = []
|
|
|
|
|
if cols:
|
|
|
|
|
for i in in_filters_tuple_list:
|
|
|
|
|
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
|
|
|
|
|
if query_records:
|
|
|
|
|
res_list.extend([query_record for query_record in query_records])
|
|
|
|
|
else:
|
|
|
|
|
for i in in_filters_tuple_list:
|
|
|
|
|
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
|
|
|
|
|
if query_records:
|
|
|
|
|
res_list.extend([query_record for query_record in query_records])
|
|
|
|
|
return res_list
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_query_session(cls, query_params, sessions=None):
|
|
|
|
|
|
|
|
|
|
def get_query_entity(cls, query_params, sessions=None):
|
|
|
|
|
if sessions is None:
|
|
|
|
|
sessions = cls.model.select()
|
|
|
|
|
for key, value in query_params.items():
|
|
|
|
|
@ -433,29 +252,10 @@ class BaseService(Generic[T]):
|
|
|
|
|
def get_data_count(cls, query_params: dict = None):
|
|
|
|
|
if not query_params:
|
|
|
|
|
raise Exception("参数为空")
|
|
|
|
|
sessions = cls.get_query_session(query_params)
|
|
|
|
|
sessions = cls.get_query_entity(query_params)
|
|
|
|
|
return sessions.count()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def is_exist(cls, query_params: dict = None):
|
|
|
|
|
return cls.get_data_count(query_params) > 0
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def update_by_id(cls, pid, data):
|
|
|
|
|
# Update a single record by ID
|
|
|
|
|
# Args:
|
|
|
|
|
# pid: Record ID
|
|
|
|
|
# data: Updated field values
|
|
|
|
|
# Returns:
|
|
|
|
|
# Number of records updated
|
|
|
|
|
data["updated_time"] = current_timestamp()
|
|
|
|
|
num = cls.model.update(data).where(cls.model.id == pid).execute()
|
|
|
|
|
return num
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def check_base_permission(cls, model_data):
|
|
|
|
|
if isinstance(model_data, dict):
|
|
|
|
|
if model_data.get("created_by") != get_current_user().id:
|
|
|
|
|
raise RuntimeError("无操作权限,该操作仅创建者有此权限")
|
|
|
|
|
if model_data.created_by != get_current_user().id:
|
|
|
|
|
raise RuntimeError("无操作权限,该操作仅创建者有此权限")
|
|
|
|
|
|