You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

462 lines
17 KiB
Python

from typing import Union, Type, List, Any, TypeVar, Generic
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from core.global_context import current_session
from utils import get_uuid, current_timestamp
T = TypeVar('T')
# Generic base class for model-backed services. The methods defined on it are
# class/static methods that operate on the model class stored in ``model``.
class BaseService(Generic[T]):
model: Type[T] # Subclasses must set this to their model class
@classmethod
def get_db(cls) -> AsyncSession:
    """Return the database session stored in the current context.

    Returns:
        AsyncSession: The value held by ``current_session``.

    Raises:
        RuntimeError: If ``current_session`` holds ``None``.
    """
    active = current_session.get()
    if active is not None:
        return active
    raise RuntimeError(
        "No database session in context. "
        "Make sure to use this service within a request context."
    )
@classmethod
async def create(cls, **kwargs) -> T:
    """Create a model instance, add it to the current session and flush.

    Args:
        **kwargs: Field values forwarded to the model constructor.

    Returns:
        The newly constructed model instance.
    """
    instance = cls.model(**kwargs)
    session = cls.get_db()
    session.add(instance)
    # Flush so pending changes are sent to the database without committing.
    await session.flush()
    return instance
@classmethod
def entity_conversion_dto(cls, entity_data: Union[list, dict, T], dto: Type[BaseModel]) -> Union[
        BaseModel, List[BaseModel]]:
    """Convert one entity, or a list of entities, into DTO instances.

    Each entity may be a plain dict or an object exposing ``to_dict()``.

    Args:
        entity_data: A single entity/dict, or a list of entities/dicts.
        dto: DTO class instantiated with the entity's fields as keyword
            arguments.

    Returns:
        A single DTO when ``entity_data`` is not a list, otherwise a list of
        DTOs in the same order as the input.
    """
    # The previous annotation referenced the class attribute ``model``, which
    # is only annotated (never bound) and therefore raised NameError while the
    # class body was being evaluated.

    def _as_dict(entity):
        # Dicts are used as-is; anything else must provide to_dict().
        return entity if isinstance(entity, dict) else entity.to_dict()

    if not isinstance(entity_data, list):
        return dto(**_as_dict(entity_data))
    return [dto(**_as_dict(entity)) for entity in entity_data]
@classmethod
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
    """Delegate a filtered, optionally ordered query to ``cls.model.query``.

    Args:
        cols (list, optional): Columns to select, forwarded unchanged.
        reverse (bool, optional): Sort-direction flag, forwarded unchanged.
        order_by (str, optional): Name of the column to sort by, forwarded
            unchanged.
        **kwargs: Additional filter conditions, forwarded unchanged.

    Returns:
        Whatever ``cls.model.query`` returns for these arguments.
    """
    options = {"cols": cols, "reverse": reverse, "order_by": order_by}
    return cls.model.query(**options, **kwargs)
@classmethod
def get_all(cls, cols=None, reverse=None, order_by=None):
    """Select every record of the model, optionally ordered.

    Ordering is applied only when ``reverse`` is exactly ``True`` or
    ``False``. When ``order_by`` is empty or is not an attribute of the
    model, ``"created_time"`` is used instead.

    Args:
        cols (list, optional): Columns to select. All columns when falsy.
        reverse (bool, optional): ``True`` for descending, ``False`` for
            ascending, ``None`` for no explicit ordering.
        order_by (str, optional): Name of the model attribute to sort by.

    Returns:
        The (unevaluated) select query.
    """
    query_records = cls.model.select(*cols) if cols else cls.model.select()
    if reverse is None:
        return query_records

    # Validate the sort column against the model. The previous check used the
    # service class itself, which made valid model columns fall back to
    # "created_time".
    if not order_by or not hasattr(cls.model, order_by):
        order_by = "created_time"

    if reverse is True:
        query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
    elif reverse is False:
        query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
    return query_records
@classmethod
def get(cls, **kwargs):
    """Fetch a single record matching the keyword filters.

    Args:
        **kwargs: Filter conditions passed straight to ``cls.model.get``.

    Returns:
        The record returned by ``cls.model.get``.

    Raises:
        Whatever ``cls.model.get`` raises when no record matches
        (presumably the model's ``DoesNotExist`` -- confirm).
    """
    model_cls = cls.model
    return model_cls.get(**kwargs)
@classmethod
def get_or_none(cls, **kwargs):
    """Fetch a single record, returning ``None`` when nothing matches.

    Args:
        **kwargs: Filter conditions passed to ``cls.model.get``.

    Returns:
        The matching model instance, or ``None`` if no record matches.
    """
    try:
        return cls.model.get(**kwargs)
    # Use the model's own DoesNotExist: the ``peewee`` module is not imported
    # in this file, so naming ``peewee.DoesNotExist`` here raised NameError
    # exactly when a record was missing.
    except cls.model.DoesNotExist:
        return None
# Return one page of records matching ``query_params``.
# ``query_params`` may be a dict or an object exposing ``model_dump()``; the
# resulting dict is turned into a filtered query by ``get_query_session`` and
# then handed to ``auto_page`` together with the same dict (which also carries
# the paging/sorting options ``auto_page`` reads).
# NOTE(review): indentation is missing in this file, so it is unclear whether
# the None-filtering line applies to plain dict input as well or only to the
# converted request object -- confirm against the original.
# NOTE(review): ``BasePageQueryReq`` is not among the imports visible at the
# top of this file -- confirm it is in scope.
@classmethod
def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
# Drop entries whose value is None so they are not used as filters.
query_params = {k: v for k, v in query_params.items() if v is not None}
sessions = cls.get_query_session(query_params)
return cls.auto_page(sessions, query_params)
@classmethod
def auto_page(cls, sessions, query_params: Union[dict, BasePageQueryReq] = None,
              dto_model_class: Type[BaseModel] = None):
    """Paginate a query and wrap the requested page in a ``BasePageResp``.

    Args:
        sessions: Query to paginate; it is asked for ``count()``,
            ``order_by()``, ``paginate()`` and ``dicts()``.
        query_params: Paging options as a dict, or an object exposing
            ``model_dump()``. Keys read: ``page_number`` (default 1),
            ``page_size`` (default 12), ``desc`` (default ``"desc"``) and
            ``orderby`` (default ``"created_time"``).
        dto_model_class: Optional class; when given, every row dict is
            converted with ``dto_model_class(**row)``.

    Returns:
        BasePageResp: Echoes the paging options, the total ``count`` and the
        page rows in ``data`` (an empty list when nothing matches).
    """
    # NOTE(review): BasePageQueryReq / BasePageResp are not among the imports
    # visible at the top of this file -- confirm they are in scope.
    params = query_params or {}
    if not isinstance(params, dict):
        params = params.model_dump()

    page_number = params.get("page_number", 1)
    page_size = params.get("page_size", 12)
    desc = params.get("desc", "desc")
    orderby = params.get("orderby", "created_time")

    total = sessions.count()
    rows = []
    # Ordering and pagination are skipped entirely for an empty result set.
    if total != 0:
        if desc == "desc":
            ordered = sessions.order_by(cls.model.getter_by(orderby).desc())
        else:
            ordered = sessions.order_by(cls.model.getter_by(orderby).asc())
        page = ordered.paginate(int(page_number), int(page_size))
        rows = list(page.dicts())
        if dto_model_class is not None:
            rows = [dto_model_class(**row) for row in rows]

    return BasePageResp(
        page_number=page_number,
        page_size=page_size,
        count=total,
        desc=desc,
        orderby=orderby,
        data=rows,
    )
# Return all records matching ``query_params`` as a list of dicts.
# ``query_params`` may be a dict or an object exposing ``model_dump()``.
# The keys "desc" (default "desc") and "orderby" (default "created_time")
# control the sort direction and sort column.
# NOTE(review): indentation is missing in this file, so it is unclear whether
# the None-filtering line applies to plain dict input as well or only to the
# converted request object -- confirm against the original.
# NOTE(review): ``BaseQueryReq`` is not among the imports visible at the top
# of this file -- confirm it is in scope.
@classmethod
def get_list(cls, query_params: Union[dict, BaseQueryReq]):
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
# Drop entries whose value is None so they are not used as filters.
query_params = {k: v for k, v in query_params.items() if v is not None}
desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time")
sessions = cls.get_query_session(query_params)
# Any value other than the string "desc" sorts ascending.
if desc == "desc":
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
return list(sessions.dicts())
# Return the ``id`` values of all records matching ``query_params``.
# Works like ``get_list`` but selects only ``cls.model.id`` and returns the
# bare id values, sorted by "orderby" (default "created_time") in the
# direction given by "desc" (default "desc").
# NOTE(review): indentation is missing in this file, so it is unclear whether
# the None-filtering line applies to plain dict input as well or only to the
# converted request object -- confirm against the original.
# NOTE(review): ``BaseQueryReq`` is not among the imports visible at the top
# of this file -- confirm it is in scope.
@classmethod
def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
# Drop entries whose value is None so they are not used as filters.
query_params = {k: v for k, v in query_params.items() if v is not None}
desc = query_params.get("desc", "desc")
orderby = query_params.get("orderby", "created_time")
# Start from an id-only select and let get_query_session add the filters.
sessions = cls.model.select(cls.model.id)
sessions = cls.get_query_session(query_params, sessions)
# Any value other than the string "desc" sorts ascending.
if desc == "desc":
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
return [item["id"] for item in list(sessions.dicts())]
@classmethod
def save(cls, **kwargs):
    """Insert a new record built from the given field values.

    The model's ``save`` is called with ``force_insert=True`` so the row is
    inserted rather than updated.

    Args:
        **kwargs: Field values forwarded to the model constructor.

    Returns:
        bool: ``True`` when the value returned by ``save`` is greater than
        zero.
    """
    record = cls.model(**kwargs)
    inserted = record.save(force_insert=True)
    return inserted > 0
@classmethod
def insert(cls, **kwargs):
    """Insert a new record, filling in the id and timestamps.

    An ``id`` is generated with ``get_uuid()`` unless the caller supplied
    one. ``created_time`` and ``updated_time`` are always overwritten with
    ``current_timestamp()``.

    Args:
        **kwargs: Field values forwarded to the model constructor.

    Returns:
        bool: ``True`` when the value returned by ``save`` is greater than
        zero.
    """
    if "id" not in kwargs:
        kwargs["id"] = get_uuid()
    kwargs.update(
        created_time=current_timestamp(),
        updated_time=current_timestamp(),
    )
    record = cls.model(**kwargs)
    inserted = record.save(force_insert=True)
    return inserted > 0
@classmethod
def insert_many(cls, data_list, batch_size=100):
    """Insert many records in batches inside one ``DB.atomic()`` block.

    Every dict in ``data_list`` is modified in place: a missing/falsy ``id``
    is replaced with ``get_uuid()`` and ``created_time`` is set to
    ``current_timestamp()``.

    Args:
        data_list (list): Dicts holding the field values of each record.
        batch_size (int, optional): Number of records per insert statement.
            Defaults to 100.
    """
    # NOTE(review): ``DB`` is not among the imports visible at the top of
    # this file -- confirm it is in scope.
    with DB.atomic():
        for row in data_list:
            if not row.get("id"):
                row["id"] = get_uuid()
            row["created_time"] = current_timestamp()
        for start in range(0, len(data_list), batch_size):
            batch = data_list[start:start + batch_size]
            cls.model.insert_many(batch).execute()
@classmethod
def update_many_by_id(cls, data_list):
    """Update several records, each identified by its ``id`` entry.

    Every dict is modified in place: ``updated_time`` is set to
    ``current_timestamp()`` before the update is issued. All updates run
    inside one ``DB.atomic()`` block.

    Args:
        data_list (list): Dicts of field values; each must contain ``"id"``.
    """
    # NOTE(review): ``DB`` is not among the imports visible at the top of
    # this file -- confirm it is in scope.
    with DB.atomic():
        for row in data_list:
            row["updated_time"] = current_timestamp()
            statement = cls.model.update(row).where(cls.model.id == row["id"])
            statement.execute()
@classmethod
def updated_by_id(cls, pid, data):
    """Update the record with the given id and report whether a row changed.

    ``data`` is modified in place: ``updated_time`` is set to
    ``current_timestamp()`` before the update is issued.

    Args:
        pid: Record id.
        data (dict): Field values to write.

    Returns:
        bool: ``True`` when the update statement reports more than zero rows.
    """
    data["updated_time"] = current_timestamp()
    statement = cls.model.update(data).where(cls.model.id == pid)
    return statement.execute() > 0
@classmethod
def get_by_id(cls, pid):
    """Look up a record by id without raising.

    Args:
        pid: Record id.

    Returns:
        tuple: ``(True, record)`` when a record is found, otherwise
        ``(False, None)``. Lookup errors also yield ``(False, None)``.
    """
    import logging

    try:
        obj = cls.model.get_or_none(cls.model.id == pid)
    except Exception:
        # Keep the best-effort contract, but record the failure instead of
        # discarding it silently.
        logging.getLogger(__name__).exception(
            "%s.get_by_id failed for id %r", cls.__name__, pid
        )
        return False, None
    if obj is None:
        return False, None
    return True, obj
@classmethod
def get_by_ids(cls, pids, cols=None):
    """Build a query for the records whose id is in ``pids``.

    Args:
        pids: Collection of record ids.
        cols (list, optional): Columns to select. All columns when falsy.

    Returns:
        The (unevaluated) select query restricted to the given ids.
    """
    selection = cls.model.select(*cols) if cols else cls.model.select()
    return selection.where(cls.model.id.in_(pids))
@classmethod
def get_last_by_create_time(cls):
    """Return the first record when ordered by ``created_time`` descending.

    Returns:
        Whatever ``first()`` yields for that ordering (presumably ``None``
        when the table is empty -- confirm).
    """
    newest_first = cls.model.select().order_by(cls.model.created_time.desc())
    return newest_first.first()
@classmethod
def delete_by_id(cls, pid):
    """Delete the record with the given id.

    Args:
        pid: Record id.

    Returns:
        The result of executing the delete statement.
    """
    statement = cls.model.delete().where(cls.model.id == pid)
    return statement.execute()
@classmethod
def delete_by_ids(cls, pids):
    """Delete all records whose id is in ``pids`` inside ``DB.atomic()``.

    Args:
        pids: Collection of record ids.

    Returns:
        The result of executing the delete statement.
    """
    # NOTE(review): ``DB`` is not among the imports visible at the top of
    # this file -- confirm it is in scope.
    with DB.atomic():
        return cls.model.delete().where(cls.model.id.in_(pids)).execute()
@classmethod
def filter_delete(cls, filters):
    """Delete every record matching all of the given filter expressions.

    Args:
        filters: Iterable of filter expressions, unpacked into ``where()``.

    Returns:
        The result of executing the delete statement.
    """
    # NOTE(review): ``DB`` is not among the imports visible at the top of
    # this file -- confirm it is in scope.
    with DB.atomic():
        return cls.model.delete().where(*filters).execute()
@classmethod
def filter_update(cls, filters, update_data):
    """Update every record matching all of the given filter expressions.

    Args:
        filters: Iterable of filter expressions, unpacked into ``where()``.
        update_data: Field values to write.

    Returns:
        The result of executing the update statement.
    """
    # NOTE(review): ``DB`` is not among the imports visible at the top of
    # this file -- confirm it is in scope.
    with DB.atomic():
        statement = cls.model.update(update_data).where(*filters)
        updated = statement.execute()
    return updated
@staticmethod
def cut_list(tar_list, n):
# Split a list into chunks of size n
# Args:
# tar_list: List to split
# n: Chunk size
# Returns:
# List of tuples containing chunks
length = len(tar_list)
arr = range(length)
result = [tuple(tar_list[x: (x + n)]) for x in arr[::n]]
return result
@classmethod
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
    """Collect records whose ``in_key`` field is in ``in_filters_list``.

    The values are split into chunks of 20 via ``cut_list`` and one query is
    issued per chunk, so no single IN clause holds more than 20 values.

    Args:
        in_key (str): Name of the model attribute used for the IN clause.
        in_filters_list: Values for the IN clause.
        filters: Extra filter expressions applied to every query.
        cols (list, optional): Columns to select. All columns when falsy.

    Returns:
        list: Matching records from all chunks, in query order.
    """
    extra_filters = filters if filters else []
    res_list = []
    for chunk in cls.cut_list(in_filters_list, 20):
        selection = cls.model.select(*cols) if cols else cls.model.select()
        query_records = selection.where(getattr(cls.model, in_key).in_(chunk), *extra_filters)
        if query_records:
            res_list.extend(query_records)
    return res_list
@classmethod
def get_query_session(cls, query_params, sessions=None):
    """Add an equality filter for every key that is a model attribute.

    Args:
        query_params (dict): Candidate filters. Keys that are not attributes
            of ``cls.model`` are ignored.
        sessions: Query to extend. A fresh ``cls.model.select()`` when
            ``None``.

    Returns:
        The query with one ``where(field == value)`` per accepted key.
    """
    query = cls.model.select() if sessions is None else sessions
    for key, value in query_params.items():
        if not hasattr(cls.model, key):
            continue
        query = query.where(getattr(cls.model, key) == value)
    return query
@classmethod
def get_data_count(cls, query_params: dict = None):
    """Count the records matching ``query_params``.

    Args:
        query_params (dict): Filters passed to ``get_query_session``.

    Returns:
        The value of ``count()`` on the filtered query.

    Raises:
        Exception: If ``query_params`` is empty or ``None``.
    """
    if query_params:
        return cls.get_query_session(query_params).count()
    raise Exception("参数为空")
@classmethod
def is_exist(cls, query_params: dict = None):
    """Tell whether at least one record matches ``query_params``.

    Args:
        query_params (dict): Filters passed on to ``get_data_count``.

    Returns:
        bool: ``True`` when the matching count is greater than zero.
    """
    matches = cls.get_data_count(query_params)
    return matches > 0
@classmethod
def update_by_id(cls, pid, data):
    """Update the record with the given id.

    ``data`` is modified in place: ``updated_time`` is set to
    ``current_timestamp()`` before the update is issued.

    Args:
        pid: Record id.
        data (dict): Field values to write.

    Returns:
        The result of executing the update statement.
    """
    data["updated_time"] = current_timestamp()
    statement = cls.model.update(data).where(cls.model.id == pid)
    return statement.execute()
@classmethod
def check_base_permission(cls, model_data):
    """Ensure the current user is the creator of ``model_data``.

    Args:
        model_data: A dict with a ``"created_by"`` key, or an object with a
            ``created_by`` attribute.

    Raises:
        RuntimeError: If ``created_by`` differs from the current user's id.
    """
    # NOTE(review): ``get_current_user`` is not among the imports visible at
    # the top of this file -- confirm it is in scope.
    # Read the creator once, by key or by attribute. Previously a dict that
    # passed the check fell through to the attribute access and raised
    # AttributeError.
    if isinstance(model_data, dict):
        creator = model_data.get("created_by")
    else:
        creator = model_data.created_by
    if creator != get_current_user().id:
        raise RuntimeError("无操作权限,该操作仅创建者有此权限")