|
|
|
@ -1,11 +1,10 @@
|
|
|
|
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
|
|
|
|
from typing import Union, Type, List, Any, TypeVar, Generic
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi_pagination import Params
|
|
|
|
from fastapi_pagination import Params
|
|
|
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
|
|
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from sqlalchemy import func
|
|
|
|
from sqlalchemy import func
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from sqlmodel import SQLModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from core.global_context import current_session
|
|
|
|
from core.global_context import current_session
|
|
|
|
from entity import DbBaseModel
|
|
|
|
from entity import DbBaseModel
|
|
|
|
@ -22,6 +21,7 @@ session.scalar: 直接明确获取一条数据,可以直接返回,无需额
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
T = TypeVar('T', bound=DbBaseModel)
|
|
|
|
T = TypeVar('T', bound=DbBaseModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseService(Generic[T]):
|
|
|
|
class BaseService(Generic[T]):
|
|
|
|
model: Type[T] # 子类必须指定模型
|
|
|
|
model: Type[T] # 子类必须指定模型
|
|
|
|
|
|
|
|
|
|
|
|
@ -66,7 +66,7 @@ class BaseService(Generic[T]):
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp:
|
|
|
|
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]:
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
@ -75,7 +75,7 @@ class BaseService(Generic[T]):
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
dto_model_class: Type[BaseModel] = None)->BasePageResp:
|
|
|
|
dto_model_class: Type[BaseModel] = None) -> BasePageResp[T]:
|
|
|
|
if not query_params:
|
|
|
|
if not query_params:
|
|
|
|
query_params = {}
|
|
|
|
query_params = {}
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
@ -116,7 +116,7 @@ class BaseService(Generic[T]):
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_list(cls, query_params: Union[dict, BaseQueryReq])->List[T]:
|
|
|
|
async def get_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[T]:
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
@ -150,7 +150,7 @@ class BaseService(Generic[T]):
|
|
|
|
return [item["id"] for item in exec_result.scalars().all()]
|
|
|
|
return [item["id"] for item in exec_result.scalars().all()]
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def save(cls, **kwargs)->T:
|
|
|
|
async def save(cls, **kwargs) -> T:
|
|
|
|
sample_obj = cls.model(**kwargs)
|
|
|
|
sample_obj = cls.model(**kwargs)
|
|
|
|
session = cls.get_db()
|
|
|
|
session = cls.get_db()
|
|
|
|
session.add(sample_obj)
|
|
|
|
session.add(sample_obj)
|
|
|
|
@ -158,7 +158,7 @@ class BaseService(Generic[T]):
|
|
|
|
return sample_obj
|
|
|
|
return sample_obj
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def insert_many(cls, data_list, batch_size=100)->None:
|
|
|
|
async def insert_many(cls, data_list, batch_size=100) -> None:
|
|
|
|
async with cls.get_db() as session:
|
|
|
|
async with cls.get_db() as session:
|
|
|
|
for d in data_list:
|
|
|
|
for d in data_list:
|
|
|
|
if not d.get("id", None):
|
|
|
|
if not d.get("id", None):
|
|
|
|
@ -168,27 +168,27 @@ class BaseService(Generic[T]):
|
|
|
|
session.add_all(data_list[i: i + batch_size])
|
|
|
|
session.add_all(data_list[i: i + batch_size])
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def update_by_id(cls, pid, data)-> int:
|
|
|
|
async def update_by_id(cls, pid, data) -> int:
|
|
|
|
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
|
|
|
|
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
|
|
|
|
session = cls.get_db()
|
|
|
|
session = cls.get_db()
|
|
|
|
result = await session.execute(update_stmt)
|
|
|
|
result = await session.execute(update_stmt)
|
|
|
|
return result.rowcount
|
|
|
|
return result.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def update_many_by_id(cls, data_list)->None:
|
|
|
|
async def update_many_by_id(cls, data_list) -> None:
|
|
|
|
async with cls.get_db() as session:
|
|
|
|
async with cls.get_db() as session:
|
|
|
|
for data in data_list:
|
|
|
|
for data in data_list:
|
|
|
|
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
|
|
|
|
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
|
|
|
|
await session.execute(stmt)
|
|
|
|
await session.execute(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_by_id(cls, pid)->T:
|
|
|
|
async def get_by_id(cls, pid) -> T:
|
|
|
|
stmt = cls.model.select().where(cls.model.id == pid)
|
|
|
|
stmt = cls.model.select().where(cls.model.id == pid)
|
|
|
|
session = cls.get_db()
|
|
|
|
session = cls.get_db()
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_by_ids(cls, pids, cols=None)->List[T]:
|
|
|
|
async def get_by_ids(cls, pids, cols=None) -> List[T]:
|
|
|
|
if cols:
|
|
|
|
if cols:
|
|
|
|
objs = cls.model.select(*cols)
|
|
|
|
objs = cls.model.select(*cols)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
@ -199,14 +199,14 @@ class BaseService(Generic[T]):
|
|
|
|
return list(result.all())
|
|
|
|
return list(result.all())
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def delete_by_id(cls, pid)-> int:
|
|
|
|
async def delete_by_id(cls, pid) -> int:
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id == pid)
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id == pid)
|
|
|
|
session = cls.get_db()
|
|
|
|
session = cls.get_db()
|
|
|
|
exec_result = await session.execute(del_stmt)
|
|
|
|
exec_result = await session.execute(del_stmt)
|
|
|
|
return exec_result.rowcount
|
|
|
|
return exec_result.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def delete_by_ids(cls, pids)-> int:
|
|
|
|
async def delete_by_ids(cls, pids) -> int:
|
|
|
|
session = cls.get_db()
|
|
|
|
session = cls.get_db()
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
|
|
|
|
result = await session.execute(del_stmt)
|
|
|
|
result = await session.execute(del_stmt)
|
|
|
|
|