feat: 支持使用自建的数据库会话

main
chenzhirong 3 months ago
parent ec16a62ee1
commit 436ae236e1

@ -5,6 +5,9 @@ from enum import IntEnum, StrEnum
class IsDelete(IntEnum):
NO_DELETE = 0
DELETE = 1
class UserRoleEnum(StrEnum):
USER="user"
ADMIN="ADMIN"
class LLMType(StrEnum):
CHAT = 'chat'
@ -12,4 +15,4 @@ class LLMType(StrEnum):
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'
TTS = 'tts'
TTS = 'tts'

@ -1,5 +1,6 @@
from sqlmodel import SQLModel
from common.global_enums import UserRoleEnum
from entity import DbBaseModel, engine
@ -13,3 +14,4 @@ class User(DbBaseModel, table=True):
__tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写
username: str
password: str
user_role: UserRoleEnum

@ -17,4 +17,6 @@ class DbSessionMiddleWare(BaseHTTPMiddleware):
finally:
# 重置上下文变量
current_session.reset(token)
# 无论成功与否,都必须关闭会话
await session.close()
return response

@ -1,4 +1,4 @@
from typing import Union, Type, List, Any, TypeVar, Generic
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate
@ -21,7 +21,6 @@ session.scalar: 直接明确获取一条数据,可以直接返回,无需额
"""
T = TypeVar('T', bound=DbBaseModel)
class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型
@ -35,17 +34,17 @@ class BaseService(Generic[T]):
return session
@classmethod
def get_query_stmt(cls, query_params, sessions=None, *, fields: list = None):
if sessions is None:
def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
if stmt is None:
if fields:
sessions = cls.model.select(*fields)
stmt = cls.model.select(*fields)
else:
sessions = cls.model.select()
stmt = cls.model.select()
for key, value in query_params.items():
if hasattr(cls.model, key) and value is not None:
field = getattr(cls.model, key)
sessions = sessions.where(field == value)
return sessions
stmt = stmt.where(field == value)
return stmt
@classmethod
def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[
@ -66,16 +65,17 @@ class BaseService(Generic[T]):
pass
@classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]:
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp[T]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
query_params = {k: v for k, v in query_params.items() if v not in [None,""]}
query_stmt = cls.get_query_stmt(query_params)
return await cls.auto_page(query_stmt, query_params)
@classmethod
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None) -> BasePageResp[T]:
dto_model_class: Type[BaseModel] = None,*, session: Optional[AsyncSession] = None)->BasePageResp[T]:
session = session or cls.get_db()
if not query_params:
query_params = {}
if not isinstance(query_params, dict):
@ -84,7 +84,6 @@ class BaseService(Generic[T]):
page_size = query_params.get("page_size", 12)
sort = query_params.get("sort", "desc")
orderby = query_params.get("orderby", "created_time")
session = cls.get_db()
data_count = None
if data_count == 0:
return BasePageResp(**{
@ -116,10 +115,11 @@ class BaseService(Generic[T]):
})
@classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[T]:
async def get_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None)->List[T]:
session = session or cls.get_db()
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
query_params = {k: v for k, v in query_params.items() if v not in [None,""]}
sort = query_params.get("sort", "desc")
orderby = query_params.get("orderby", "created_time")
query_stmt = cls.get_query_stmt(query_params)
@ -128,12 +128,12 @@ class BaseService(Generic[T]):
query_stmt = query_stmt.order_by(field.desc())
else:
query_stmt = query_stmt.order_by(field.asc())
session = cls.get_db()
exec_result = await session.execute(query_stmt)
return list(exec_result.scalars().all())
@classmethod
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[str]:
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None) -> List[str]:
session = session or cls.get_db()
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
@ -145,21 +145,21 @@ class BaseService(Generic[T]):
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc())
else:
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc())
session = cls.get_db()
exec_result = await session.execute(query_stmt)
return [item["id"] for item in exec_result.scalars().all()]
@classmethod
async def save(cls, **kwargs) -> T:
async def save(cls,*, session: Optional[AsyncSession] = None, **kwargs)->T:
session = session or cls.get_db()
sample_obj = cls.model(**kwargs)
session = cls.get_db()
session.add(sample_obj)
await session.flush()
return sample_obj
@classmethod
async def insert_many(cls, data_list, batch_size=100) -> None:
async with cls.get_db() as session:
async def insert_many(cls, data_list, batch_size=100,*, session: Optional[AsyncSession] = None)->None:
session = session or cls.get_db()
async with session:
for d in data_list:
if not d.get("id", None):
d["id"] = get_uuid()
@ -168,59 +168,61 @@ class BaseService(Generic[T]):
session.add_all(data_list[i: i + batch_size])
@classmethod
async def update_by_id(cls, pid, data) -> int:
async def update_by_id(cls, pid, data,*, session: Optional[AsyncSession] = None)-> int:
session = session or cls.get_db()
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
session = cls.get_db()
result = await session.execute(update_stmt)
return result.rowcount
@classmethod
async def update_many_by_id(cls, data_list) -> None:
async with cls.get_db() as session:
async def update_many_by_id(cls, data_list,*, session: Optional[AsyncSession] = None)->None:
session = session or cls.get_db()
async with session:
for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt)
@classmethod
async def get_by_id(cls, pid) -> T:
async def get_by_id(cls, pid,*, session: Optional[AsyncSession] = None)->T:
session = session or cls.get_db()
stmt = cls.model.select().where(cls.model.id == pid)
session = cls.get_db()
return await session.scalar(stmt)
@classmethod
async def get_by_ids(cls, pids, cols=None) -> List[T]:
async def get_by_ids(cls, pids, cols=None,*, session: Optional[AsyncSession] = None)->List[T]:
session = session or cls.get_db()
if cols:
objs = cls.model.select(*cols)
else:
objs = cls.model.select()
stmt = objs.where(cls.model.id.in_(pids))
session = cls.get_db()
result = await session.scalars(stmt)
return list(result.all())
@classmethod
async def delete_by_id(cls, pid) -> int:
async def delete_by_id(cls, pid,*, session: Optional[AsyncSession] = None)-> int:
session = session or cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id == pid)
session = cls.get_db()
exec_result = await session.execute(del_stmt)
return exec_result.rowcount
@classmethod
async def delete_by_ids(cls, pids) -> int:
session = cls.get_db()
async def delete_by_ids(cls, pids,*, session: Optional[AsyncSession] = None)-> int:
session = session or cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt)
return result.rowcount
@classmethod
async def get_data_count(cls, query_params: dict = None) -> int:
async def get_data_count(cls, query_params: dict = None,*, session: Optional[AsyncSession] = None) -> int:
session = session or cls.get_db()
if not query_params:
raise Exception("参数为空")
stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)])
# stmt = cls.get_query_stmt(query_params)
session = cls.get_db()
return await session.scalar(stmt)
@classmethod
async def is_exist(cls, query_params: dict = None) -> bool:
async def is_exist(cls, query_params: dict = None):
return await cls.get_data_count(query_params) > 0

Loading…
Cancel
Save