|
|
|
@ -1,13 +1,13 @@
|
|
|
|
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
|
|
|
|
from typing import Union, Type, List, Any, TypeVar, Generic, Optional
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi_pagination import Params
|
|
|
|
from fastapi_pagination import Params
|
|
|
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
|
|
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from sqlalchemy import func
|
|
|
|
from sqlalchemy import func
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
from sqlmodel import SQLModel
|
|
|
|
|
|
|
|
|
|
|
|
from core.global_context import current_session
|
|
|
|
from entity import with_db_session
|
|
|
|
from entity import DbBaseModel
|
|
|
|
|
|
|
|
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
|
|
|
|
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
|
|
|
|
from utils import get_uuid
|
|
|
|
from utils import get_uuid
|
|
|
|
|
|
|
|
|
|
|
|
@ -19,20 +19,12 @@ session.scalars: 只适合单模型查询(不适合指定列或连表查询)
|
|
|
|
session.scalar: 直接明确获取一条数据,可以直接返回,无需额外处理
|
|
|
|
session.scalar: 直接明确获取一条数据,可以直接返回,无需额外处理
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
T = TypeVar('T', bound=DbBaseModel)
|
|
|
|
T = TypeVar('T', bound=SQLModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseService(Generic[T]):
|
|
|
|
class BaseService(Generic[T]):
|
|
|
|
model: Type[T] # 子类必须指定模型
|
|
|
|
model: Type[T] # 子类必须指定模型
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def get_db(cls) -> AsyncSession:
|
|
|
|
|
|
|
|
"""获取当前请求的会话"""
|
|
|
|
|
|
|
|
session = current_session.get()
|
|
|
|
|
|
|
|
if session is None:
|
|
|
|
|
|
|
|
raise RuntimeError("No database session in context. "
|
|
|
|
|
|
|
|
"Make sure to use this service within a request context.")
|
|
|
|
|
|
|
|
return session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
|
|
|
|
def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
|
|
|
|
if stmt is None:
|
|
|
|
if stmt is None:
|
|
|
|
@ -41,8 +33,15 @@ class BaseService(Generic[T]):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
stmt = cls.model.select()
|
|
|
|
stmt = cls.model.select()
|
|
|
|
for key, value in query_params.items():
|
|
|
|
for key, value in query_params.items():
|
|
|
|
if hasattr(cls.model, key) and value is not None:
|
|
|
|
if value is None:
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
if isinstance(key, str) and hasattr(cls.model, key): # 第一步:先确定 key 的类型
|
|
|
|
|
|
|
|
# 第二步:根据类型,用对应的方式处理
|
|
|
|
field = getattr(cls.model, key)
|
|
|
|
field = getattr(cls.model, key)
|
|
|
|
|
|
|
|
elif hasattr(key, 'model') and key.model is cls.model:
|
|
|
|
|
|
|
|
field = key
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
continue
|
|
|
|
stmt = stmt.where(field == value)
|
|
|
|
stmt = stmt.where(field == value)
|
|
|
|
return stmt
|
|
|
|
return stmt
|
|
|
|
|
|
|
|
|
|
|
|
@ -55,7 +54,7 @@ class BaseService(Generic[T]):
|
|
|
|
for entity in entity_data:
|
|
|
|
for entity in entity_data:
|
|
|
|
temp = entity
|
|
|
|
temp = entity
|
|
|
|
if not isinstance(entity, dict):
|
|
|
|
if not isinstance(entity, dict):
|
|
|
|
temp = entity.to_dict()
|
|
|
|
temp = entity.model_dump()
|
|
|
|
dto_list.append(dto(**temp))
|
|
|
|
dto_list.append(dto(**temp))
|
|
|
|
return dto_list
|
|
|
|
return dto_list
|
|
|
|
|
|
|
|
|
|
|
|
@ -73,9 +72,10 @@ class BaseService(Generic[T]):
|
|
|
|
return await cls.auto_page(query_stmt, query_params)
|
|
|
|
return await cls.auto_page(query_stmt, query_params)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
|
|
|
|
dto_model_class: Type[BaseModel] = None,*, session: Optional[AsyncSession] = None)->BasePageResp[T]:
|
|
|
|
dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> \
|
|
|
|
session = session or cls.get_db()
|
|
|
|
BasePageResp[T]:
|
|
|
|
if not query_params:
|
|
|
|
if not query_params:
|
|
|
|
query_params = {}
|
|
|
|
query_params = {}
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
@ -115,8 +115,9 @@ class BaseService(Generic[T]):
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None)->List[T]:
|
|
|
|
@with_db_session()
|
|
|
|
session = session or cls.get_db()
|
|
|
|
async def get_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> List[
|
|
|
|
|
|
|
|
T]:
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
|
|
|
|
@ -128,12 +129,24 @@ class BaseService(Generic[T]):
|
|
|
|
query_stmt = query_stmt.order_by(field.desc())
|
|
|
|
query_stmt = query_stmt.order_by(field.desc())
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
query_stmt = query_stmt.order_by(field.asc())
|
|
|
|
query_stmt = query_stmt.order_by(field.asc())
|
|
|
|
|
|
|
|
if query_params.get("limit", None) is not None:
|
|
|
|
|
|
|
|
query_stmt = query_stmt.limit(query_params.get("limit"))
|
|
|
|
exec_result = await session.execute(query_stmt)
|
|
|
|
exec_result = await session.execute(query_stmt)
|
|
|
|
return list(exec_result.scalars().all())
|
|
|
|
return list(exec_result.scalars().all())
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None) -> List[str]:
|
|
|
|
@with_db_session()
|
|
|
|
session = session or cls.get_db()
|
|
|
|
async def get_list_json(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
|
|
|
|
|
|
|
|
List[
|
|
|
|
|
|
|
|
T]:
|
|
|
|
|
|
|
|
resp_list = await cls.get_list(query_params, session=session)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return [i.model_dump() for i in resp_list]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
|
|
|
|
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
|
|
|
|
|
|
|
|
List[str]:
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v is not None}
|
|
|
|
@ -149,16 +162,24 @@ class BaseService(Generic[T]):
|
|
|
|
return [item["id"] for item in exec_result.scalars().all()]
|
|
|
|
return [item["id"] for item in exec_result.scalars().all()]
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T:
|
|
|
|
async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
sample_obj = cls.model(**kwargs)
|
|
|
|
sample_obj = cls.model(**kwargs)
|
|
|
|
session.add(sample_obj)
|
|
|
|
session.add(sample_obj)
|
|
|
|
await session.flush()
|
|
|
|
await session.flush()
|
|
|
|
return sample_obj
|
|
|
|
return sample_obj
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
|
|
|
|
async def save_entity(cls, db_model: SQLModel, *, session: Optional[AsyncSession] = None) -> T:
|
|
|
|
|
|
|
|
session.add(db_model)
|
|
|
|
|
|
|
|
await session.flush()
|
|
|
|
|
|
|
|
return db_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def insert_many(cls, data_list, batch_size=100, *, session: Optional[AsyncSession] = None) -> None:
|
|
|
|
async def insert_many(cls, data_list, batch_size=100, *, session: Optional[AsyncSession] = None) -> None:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
async with session:
|
|
|
|
async with session:
|
|
|
|
for d in data_list:
|
|
|
|
for d in data_list:
|
|
|
|
if not d.get("id", None):
|
|
|
|
if not d.get("id", None):
|
|
|
|
@ -168,30 +189,41 @@ class BaseService(Generic[T]):
|
|
|
|
session.add_all(data_list[i: i + batch_size])
|
|
|
|
session.add_all(data_list[i: i + batch_size])
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def update_by_id(cls, pid, data, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
async def update_by_id(cls, pid, data, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
|
|
|
|
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
|
|
|
|
result = await session.execute(update_stmt)
|
|
|
|
result = await session.execute(update_stmt)
|
|
|
|
return result.rowcount
|
|
|
|
return result.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def update_many_by_id(cls, data_list, *, session: Optional[AsyncSession] = None) -> None:
|
|
|
|
async def update_many_by_id(cls, data_list, *, session: Optional[AsyncSession] = None) -> None:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
async with session:
|
|
|
|
async with session:
|
|
|
|
for data in data_list:
|
|
|
|
for data in data_list:
|
|
|
|
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
|
|
|
|
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
|
|
|
|
await session.execute(stmt)
|
|
|
|
await session.execute(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T:
|
|
|
|
async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
stmt = cls.model.select().where(cls.model.id == pid)
|
|
|
|
stmt = cls.model.select().where(cls.model.id == pid)
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
return await session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
|
|
|
|
async def get_one(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> T:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(query_params, dict):
|
|
|
|
|
|
|
|
query_params = query_params.model_dump()
|
|
|
|
|
|
|
|
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
|
|
|
|
|
|
|
|
query_stmt = cls.get_query_stmt(query_params)
|
|
|
|
|
|
|
|
return await session.scalar(query_stmt)
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def get_by_ids(cls, pids, cols=None, *, session: Optional[AsyncSession] = None) -> List[T]:
|
|
|
|
async def get_by_ids(cls, pids, cols=None, *, session: Optional[AsyncSession] = None) -> List[T]:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
if cols:
|
|
|
|
if cols:
|
|
|
|
objs = cls.model.select(*cols)
|
|
|
|
objs = cls.model.select(*cols)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
@ -201,22 +233,35 @@ class BaseService(Generic[T]):
|
|
|
|
return list(result.all())
|
|
|
|
return list(result.all())
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
|
|
|
|
async def delete(cls, delete_params: dict, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
del_stmt = cls.model.delete()
|
|
|
|
|
|
|
|
for k, v in delete_params.items():
|
|
|
|
|
|
|
|
del_stmt = del_stmt.where(getattr(cls.model, k) == v)
|
|
|
|
|
|
|
|
exec_result = await session.execute(del_stmt)
|
|
|
|
|
|
|
|
return exec_result.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def delete_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
async def delete_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id == pid)
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id == pid)
|
|
|
|
exec_result = await session.execute(del_stmt)
|
|
|
|
exec_result = await session.execute(del_stmt)
|
|
|
|
return exec_result.rowcount
|
|
|
|
return exec_result.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def delete_by_ids(cls, pids, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
async def delete_by_ids(cls, pids, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
|
|
|
|
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
|
|
|
|
result = await session.execute(del_stmt)
|
|
|
|
result = await session.execute(del_stmt)
|
|
|
|
return result.rowcount
|
|
|
|
return result.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
@with_db_session()
|
|
|
|
async def get_data_count(cls, query_params: dict = None, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
async def get_data_count(cls, query_params: dict = None, *, session: Optional[AsyncSession] = None) -> int:
|
|
|
|
session = session or cls.get_db()
|
|
|
|
|
|
|
|
if not query_params:
|
|
|
|
if not query_params:
|
|
|
|
raise Exception("参数为空")
|
|
|
|
raise Exception("参数为空")
|
|
|
|
stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)])
|
|
|
|
stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)])
|
|
|
|
|