|
|
|
@ -1,19 +1,25 @@
|
|
|
|
from typing import Union, Type, List, Any, TypeVar, Generic
|
|
|
|
from typing import Union, Type, List, Any
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi_pagination import Params, Page
|
|
|
|
from fastapi_pagination import Params
|
|
|
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
|
|
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from sqlalchemy import Select, select, func
|
|
|
|
from sqlalchemy import func
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
|
|
from core.global_context import current_session
|
|
|
|
from core.global_context import current_session
|
|
|
|
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
|
|
|
|
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
|
|
|
|
from utils import get_uuid, current_timestamp
|
|
|
|
from utils import get_uuid, current_timestamp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
session.execute: 执行任意数据库操作语句,返回结果需要额外处理获取 数据形式:Row对象(类似元组)
|
|
|
|
|
|
|
|
-- exec_result.scalars().all(): 从session.execute的执行结果中获取全部的数据
|
|
|
|
|
|
|
|
session.scalars: 只适合单模型查询(不适合指定列或连表查询),返回结果需要处理 数据形式:直接返回模型实例(如User对象)
|
|
|
|
|
|
|
|
-- 处理exec_result.all(): 从session.scalars的执行结果中获取全部的数据
|
|
|
|
|
|
|
|
session.scalar: 直接明确获取一条数据,可以直接返回,无需额外处理
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
class BaseService:
|
|
|
|
class BaseService:
|
|
|
|
model=None # 子类必须指定模型
|
|
|
|
model = None # 子类必须指定模型
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_db(cls) -> AsyncSession:
|
|
|
|
def get_db(cls) -> AsyncSession:
|
|
|
|
@ -23,13 +29,24 @@ class BaseService:
|
|
|
|
raise RuntimeError("No database session in context. "
|
|
|
|
raise RuntimeError("No database session in context. "
|
|
|
|
"Make sure to use this service within a request context.")
|
|
|
|
"Make sure to use this service within a request context.")
|
|
|
|
return session
|
|
|
|
return session
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def entity_conversion_dto(cls, entity_data: Union[list, model], dto: Type[BaseModel]) -> Union[
|
|
|
|
def get_query_stmt(cls, query_params, sessions=None,*,fields:list=None):
|
|
|
|
|
|
|
|
if sessions is None:
|
|
|
|
|
|
|
|
if fields:
|
|
|
|
|
|
|
|
sessions = cls.model.select(*fields)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
sessions = cls.model.select()
|
|
|
|
|
|
|
|
for key, value in query_params.items():
|
|
|
|
|
|
|
|
if hasattr(cls.model, key) and value is not None:
|
|
|
|
|
|
|
|
field = getattr(cls.model, key)
|
|
|
|
|
|
|
|
sessions = sessions.where(field == value)
|
|
|
|
|
|
|
|
return sessions
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[
|
|
|
|
BaseModel, List[BaseModel]]:
|
|
|
|
BaseModel, List[BaseModel]]:
|
|
|
|
dto_list = []
|
|
|
|
dto_list = []
|
|
|
|
if not isinstance(entity_data, list):
|
|
|
|
if not isinstance(entity_data, list):
|
|
|
|
return dto(**entity_data.to_dict())
|
|
|
|
return dto(**entity_data.model_dump())
|
|
|
|
for entity in entity_data:
|
|
|
|
for entity in entity_data:
|
|
|
|
temp = entity
|
|
|
|
temp = entity
|
|
|
|
if not isinstance(entity, dict):
|
|
|
|
if not isinstance(entity, dict):
|
|
|
|
@ -37,22 +54,17 @@ class BaseService:
|
|
|
|
dto_list.append(dto(**temp))
|
|
|
|
dto_list.append(dto(**temp))
|
|
|
|
return dto_list
|
|
|
|
return dto_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
|
|
|
|
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
query_entity = cls.get_query_entity(query_params)
|
|
|
|
query_stmt = cls.get_query_stmt(query_params)
|
|
|
|
return await cls.auto_page(query_entity, query_params)
|
|
|
|
return await cls.auto_page(query_stmt, query_params)
|
|
|
|
# @classmethod
|
|
|
|
|
|
|
|
# def count_query(cls,query: Select) -> Select:
|
|
|
|
|
|
|
|
# # type: ignore
|
|
|
|
|
|
|
|
# return select(func.count("*")).select_from(count_subquery)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def auto_page(cls, query_entity, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
dto_model_class: Type[BaseModel] = None):
|
|
|
|
dto_model_class: Type[BaseModel] = None):
|
|
|
|
if not query_params:
|
|
|
|
if not query_params:
|
|
|
|
query_params = {}
|
|
|
|
query_params = {}
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
@ -61,11 +73,8 @@ class BaseService:
|
|
|
|
page_size = query_params.get("page_size", 12)
|
|
|
|
page_size = query_params.get("page_size", 12)
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
# data_count = await sessions.count()
|
|
|
|
|
|
|
|
session = cls.get_db()
|
|
|
|
session = cls.get_db()
|
|
|
|
|
|
|
|
data_count = None
|
|
|
|
# data_count = session.scalar(cls.count_query(query_entity))
|
|
|
|
|
|
|
|
data_count =None
|
|
|
|
|
|
|
|
if data_count == 0:
|
|
|
|
if data_count == 0:
|
|
|
|
return BasePageResp(**{
|
|
|
|
return BasePageResp(**{
|
|
|
|
"page_number": page_number,
|
|
|
|
"page_number": page_number,
|
|
|
|
@ -76,17 +85,12 @@ class BaseService:
|
|
|
|
"data": [],
|
|
|
|
"data": [],
|
|
|
|
})
|
|
|
|
})
|
|
|
|
if desc == "desc":
|
|
|
|
if desc == "desc":
|
|
|
|
|
|
|
|
query_stmt = query_stmt.order_by(getattr(cls.model, orderby).desc())
|
|
|
|
query_entity = query_entity.order_by(getattr(cls.model,orderby).desc())
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
query_entity = query_entity.order_by(getattr(cls.model,orderby).asc())
|
|
|
|
query_stmt = query_stmt.order_by(getattr(cls.model, orderby).asc())
|
|
|
|
query_page_result=await paginate(session,
|
|
|
|
query_page_result = await paginate(session,
|
|
|
|
query_entity,
|
|
|
|
query_stmt,
|
|
|
|
Params(page=page_number,size=page_size))
|
|
|
|
Params(page=page_number, size=page_size))
|
|
|
|
# query_entity = query_entity.offset((page_number - 1) * page_size).limit(page_size)
|
|
|
|
|
|
|
|
# query_exec_result = await session.execute(query_entity)
|
|
|
|
|
|
|
|
# result = query_exec_result.scalars().all()
|
|
|
|
|
|
|
|
# return query_page_result
|
|
|
|
|
|
|
|
result = query_page_result.items
|
|
|
|
result = query_page_result.items
|
|
|
|
if dto_model_class is not None:
|
|
|
|
if dto_model_class is not None:
|
|
|
|
result = [dto_model_class(**item) for item in result]
|
|
|
|
result = [dto_model_class(**item) for item in result]
|
|
|
|
@ -101,161 +105,112 @@ class BaseService:
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_list(cls, query_params: Union[dict, BaseQueryReq]):
|
|
|
|
async def get_list(cls, query_params: Union[dict, BaseQueryReq]):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
sessions = cls.get_query_entity(query_params)
|
|
|
|
query_stmt = cls.get_query_stmt(query_params)
|
|
|
|
|
|
|
|
field = getattr(cls.model, orderby)
|
|
|
|
if desc == "desc":
|
|
|
|
if desc == "desc":
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
|
|
|
query_stmt = query_stmt.order_by(field.desc())
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
|
|
|
query_stmt = query_stmt.order_by(field.asc())
|
|
|
|
|
|
|
|
session = cls.get_db()
|
|
|
|
return sessions.scalars().all()
|
|
|
|
exec_result = await session.execute(query_stmt)
|
|
|
|
|
|
|
|
return exec_result.scalars().all()
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
|
|
|
|
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
desc = query_params.get("desc", "desc")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
orderby = query_params.get("orderby", "created_time")
|
|
|
|
sessions = cls.model.select(cls.model.id)
|
|
|
|
query_stmt = cls.model.select(cls.model.id)
|
|
|
|
sessions = cls.get_query_entity(query_params, sessions)
|
|
|
|
query_stmt = cls.get_query_stmt(query_params, query_stmt)
|
|
|
|
if desc == "desc":
|
|
|
|
if desc == "desc":
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
|
|
|
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc())
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
|
|
|
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc())
|
|
|
|
return [item["id"] for item in sessions.scalars().all()]
|
|
|
|
session = cls.get_db()
|
|
|
|
|
|
|
|
exec_result = await session.execute(query_stmt)
|
|
|
|
|
|
|
|
return [item["id"] for item in exec_result.scalars().all()]
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def save(cls, **kwargs):
|
|
|
|
async def save(cls, **kwargs):
|
|
|
|
"""Save a new record to database.
|
|
|
|
sample_obj = cls.model(**kwargs)
|
|
|
|
|
|
|
|
session = cls.get_db()
|
|
|
|
This method creates a new record in the database with the provided field values,
|
|
|
|
session.add(sample_obj)
|
|
|
|
forcing an insert operation rather than an update.
|
|
|
|
await session.flush()
|
|
|
|
|
|
|
|
return sample_obj
|
|
|
|
Args:
|
|
|
|
|
|
|
|
**kwargs: Record field values as keyword arguments.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
Model instance: The created record object.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# todo
|
|
|
|
|
|
|
|
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
|
|
|
|
|
|
|
return sample_obj > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def insert_many(cls, data_list, batch_size=100):
|
|
|
|
async def insert_many(cls, data_list, batch_size=100):
|
|
|
|
"""Insert multiple records in batches.
|
|
|
|
async with cls.get_db() as session:
|
|
|
|
|
|
|
|
|
|
|
|
This method efficiently inserts multiple records into the database using batch processing.
|
|
|
|
|
|
|
|
It automatically sets creation timestamps for all records.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
data_list (list): List of dictionaries containing record data to insert.
|
|
|
|
|
|
|
|
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with DB.atomic():
|
|
|
|
|
|
|
|
for d in data_list:
|
|
|
|
for d in data_list:
|
|
|
|
|
|
|
|
|
|
|
|
if not d.get("id", None):
|
|
|
|
if not d.get("id", None):
|
|
|
|
d["id"] = get_uuid()
|
|
|
|
d["id"] = get_uuid()
|
|
|
|
d["created_time"] = current_timestamp()
|
|
|
|
|
|
|
|
# d["create_date"] = datetime_format(datetime.now())
|
|
|
|
|
|
|
|
for i in range(0, len(data_list), batch_size):
|
|
|
|
for i in range(0, len(data_list), batch_size):
|
|
|
|
cls.model.insert_many(data_list[i: i + batch_size]).execute()
|
|
|
|
session.add_all(data_list[i: i + batch_size])
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def update_by_id(cls, pid, data):
|
|
|
|
async def update_by_id(cls, pid, data):
|
|
|
|
# Update a single record by ID
|
|
|
|
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
|
|
|
|
# Args:
|
|
|
|
session = cls.get_db()
|
|
|
|
# pid: Record ID
|
|
|
|
result = await session.execute(update_stmt)
|
|
|
|
# data: Updated field values
|
|
|
|
return result.rowcount
|
|
|
|
# Returns:
|
|
|
|
|
|
|
|
# Number of records updated
|
|
|
|
|
|
|
|
data["updated_time"] = current_timestamp()
|
|
|
|
|
|
|
|
num = cls.model.update(data).where(cls.model.id == pid).execute()
|
|
|
|
|
|
|
|
return num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def update_many_by_id(cls, data_list):
|
|
|
|
async def update_many_by_id(cls, data_list):
|
|
|
|
"""Update multiple records by their IDs.
|
|
|
|
async with cls.get_db() as session:
|
|
|
|
|
|
|
|
|
|
|
|
This method updates multiple records in the database, identified by their IDs.
|
|
|
|
|
|
|
|
It automatically updates the updated_time and update_date fields for each record.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
data_list (list): List of dictionaries containing record data to update.
|
|
|
|
|
|
|
|
Each dictionary must include an 'id' field.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with DB.atomic():
|
|
|
|
|
|
|
|
for data in data_list:
|
|
|
|
for data in data_list:
|
|
|
|
data["updated_time"] = current_timestamp()
|
|
|
|
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
|
|
|
|
# data["update_date"] = datetime_format(datetime.now())
|
|
|
|
await session.execute(stmt)
|
|
|
|
cls.model.update(data).where(cls.model.id == data["id"]).execute()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_by_id(cls, pid):
|
|
|
|
async def get_by_id(cls, pid):
|
|
|
|
# Get a record by ID
|
|
|
|
stmt = cls.model.select(cls.model.id == pid)
|
|
|
|
# Args:
|
|
|
|
session = cls.get_db()
|
|
|
|
# pid: Record ID
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
# Returns:
|
|
|
|
|
|
|
|
# Tuple of (success, record)
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
obj = cls.model.get_or_none(cls.model.id == pid)
|
|
|
|
|
|
|
|
if obj:
|
|
|
|
|
|
|
|
return True, obj
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
return False, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_by_ids(cls, pids, cols=None):
|
|
|
|
async def get_by_ids(cls, pids, cols=None):
|
|
|
|
# Get multiple records by their IDs
|
|
|
|
|
|
|
|
# Args:
|
|
|
|
|
|
|
|
# pids: List of record IDs
|
|
|
|
|
|
|
|
# cols: List of columns to select
|
|
|
|
|
|
|
|
# Returns:
|
|
|
|
|
|
|
|
# Query of matching records
|
|
|
|
|
|
|
|
if cols:
|
|
|
|
if cols:
|
|
|
|
objs = cls.model.select(*cols)
|
|
|
|
objs = cls.model.select(*cols)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
objs = cls.model.select()
|
|
|
|
objs = cls.model.select()
|
|
|
|
return objs.where(cls.model.id.in_(pids))
|
|
|
|
stmt = objs.where(cls.model.id.in_(pids))
|
|
|
|
|
|
|
|
session = cls.get_db()
|
|
|
|
|
|
|
|
result = await session.scalars(stmt)
|
|
|
|
|
|
|
|
return result.all()
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def delete_by_id(cls, pid):
|
|
|
|
async def delete_by_id(cls, pid):
|
|
|
|
...
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id == pid)
|
|
|
|
@classmethod
|
|
|
|
session = cls.get_db()
|
|
|
|
def delete_by_ids(cls, pids):
|
|
|
|
exec_result = await session.execute(del_stmt)
|
|
|
|
...
|
|
|
|
return exec_result.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_query_entity(cls, query_params, sessions=None):
|
|
|
|
async def delete_by_ids(cls, pids):
|
|
|
|
if sessions is None:
|
|
|
|
session = cls.get_db()
|
|
|
|
sessions = cls.model.select()
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
|
|
|
|
for key, value in query_params.items():
|
|
|
|
result = await session.execute(del_stmt)
|
|
|
|
if hasattr(cls.model, key):
|
|
|
|
return result.rowcount
|
|
|
|
field = getattr(cls.model, key)
|
|
|
|
|
|
|
|
sessions = sessions.where(field == value)
|
|
|
|
|
|
|
|
return sessions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_data_count(cls, query_params: dict = None):
|
|
|
|
async def get_data_count(cls, query_params: dict = None)->int:
|
|
|
|
if not query_params:
|
|
|
|
if not query_params:
|
|
|
|
raise Exception("参数为空")
|
|
|
|
raise Exception("参数为空")
|
|
|
|
sessions = cls.get_query_entity(query_params)
|
|
|
|
stmt = cls.get_query_stmt(query_params,fields=[func.count(cls.model.id)])
|
|
|
|
return sessions.count()
|
|
|
|
# stmt = cls.get_query_stmt(query_params)
|
|
|
|
|
|
|
|
session = cls.get_db()
|
|
|
|
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def is_exist(cls, query_params: dict = None):
|
|
|
|
|
|
|
|
return cls.get_data_count(query_params) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
async def is_exist(cls, query_params: dict = None):
|
|
|
|
|
|
|
|
return await cls.get_data_count(query_params) > 0
|
|
|
|
|