feat: 支持使用自建的数据库会话

main
chenzhirong 3 months ago
parent ec16a62ee1
commit 436ae236e1

@ -5,6 +5,9 @@ from enum import IntEnum, StrEnum
class IsDelete(IntEnum): class IsDelete(IntEnum):
NO_DELETE = 0 NO_DELETE = 0
DELETE = 1 DELETE = 1
class UserRoleEnum(StrEnum):
USER="user"
ADMIN="ADMIN"
class LLMType(StrEnum): class LLMType(StrEnum):
CHAT = 'chat' CHAT = 'chat'
@ -12,4 +15,4 @@ class LLMType(StrEnum):
SPEECH2TEXT = 'speech2text' SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text' IMAGE2TEXT = 'image2text'
RERANK = 'rerank' RERANK = 'rerank'
TTS = 'tts' TTS = 'tts'

@ -1,5 +1,6 @@
from sqlmodel import SQLModel from sqlmodel import SQLModel
from common.global_enums import UserRoleEnum
from entity import DbBaseModel, engine from entity import DbBaseModel, engine
@ -13,3 +14,4 @@ class User(DbBaseModel, table=True):
__tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写 __tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写
username: str username: str
password: str password: str
user_role: UserRoleEnum

@ -17,4 +17,6 @@ class DbSessionMiddleWare(BaseHTTPMiddleware):
finally: finally:
# 重置上下文变量 # 重置上下文变量
current_session.reset(token) current_session.reset(token)
# 无论成功与否,都必须关闭会话
await session.close()
return response return response

@ -1,4 +1,4 @@
from typing import Union, Type, List, Any, TypeVar, Generic from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.ext.sqlalchemy import paginate
@ -21,7 +21,6 @@ session.scalar: 直接明确获取一条数据,可以直接返回,无需额
""" """
T = TypeVar('T', bound=DbBaseModel) T = TypeVar('T', bound=DbBaseModel)
class BaseService(Generic[T]): class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型 model: Type[T] # 子类必须指定模型
@ -35,17 +34,17 @@ class BaseService(Generic[T]):
return session return session
@classmethod @classmethod
def get_query_stmt(cls, query_params, sessions=None, *, fields: list = None): def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
if sessions is None: if stmt is None:
if fields: if fields:
sessions = cls.model.select(*fields) stmt = cls.model.select(*fields)
else: else:
sessions = cls.model.select() stmt = cls.model.select()
for key, value in query_params.items(): for key, value in query_params.items():
if hasattr(cls.model, key) and value is not None: if hasattr(cls.model, key) and value is not None:
field = getattr(cls.model, key) field = getattr(cls.model, key)
sessions = sessions.where(field == value) stmt = stmt.where(field == value)
return sessions return stmt
@classmethod @classmethod
def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[ def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[
@ -66,16 +65,17 @@ class BaseService(Generic[T]):
pass pass
@classmethod @classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]: async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp[T]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v not in [None,""]}
query_stmt = cls.get_query_stmt(query_params) query_stmt = cls.get_query_stmt(query_params)
return await cls.auto_page(query_stmt, query_params) return await cls.auto_page(query_stmt, query_params)
@classmethod @classmethod
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None) -> BasePageResp[T]: dto_model_class: Type[BaseModel] = None,*, session: Optional[AsyncSession] = None)->BasePageResp[T]:
session = session or cls.get_db()
if not query_params: if not query_params:
query_params = {} query_params = {}
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
@ -84,7 +84,6 @@ class BaseService(Generic[T]):
page_size = query_params.get("page_size", 12) page_size = query_params.get("page_size", 12)
sort = query_params.get("sort", "desc") sort = query_params.get("sort", "desc")
orderby = query_params.get("orderby", "created_time") orderby = query_params.get("orderby", "created_time")
session = cls.get_db()
data_count = None data_count = None
if data_count == 0: if data_count == 0:
return BasePageResp(**{ return BasePageResp(**{
@ -116,10 +115,11 @@ class BaseService(Generic[T]):
}) })
@classmethod @classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[T]: async def get_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None)->List[T]:
session = session or cls.get_db()
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v not in [None,""]}
sort = query_params.get("sort", "desc") sort = query_params.get("sort", "desc")
orderby = query_params.get("orderby", "created_time") orderby = query_params.get("orderby", "created_time")
query_stmt = cls.get_query_stmt(query_params) query_stmt = cls.get_query_stmt(query_params)
@ -128,12 +128,12 @@ class BaseService(Generic[T]):
query_stmt = query_stmt.order_by(field.desc()) query_stmt = query_stmt.order_by(field.desc())
else: else:
query_stmt = query_stmt.order_by(field.asc()) query_stmt = query_stmt.order_by(field.asc())
session = cls.get_db()
exec_result = await session.execute(query_stmt) exec_result = await session.execute(query_stmt)
return list(exec_result.scalars().all()) return list(exec_result.scalars().all())
@classmethod @classmethod
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[str]: async def get_id_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None) -> List[str]:
session = session or cls.get_db()
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
@ -145,21 +145,21 @@ class BaseService(Generic[T]):
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc()) query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc())
else: else:
query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc()) query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc())
session = cls.get_db()
exec_result = await session.execute(query_stmt) exec_result = await session.execute(query_stmt)
return [item["id"] for item in exec_result.scalars().all()] return [item["id"] for item in exec_result.scalars().all()]
@classmethod @classmethod
async def save(cls, **kwargs) -> T: async def save(cls,*, session: Optional[AsyncSession] = None, **kwargs)->T:
session = session or cls.get_db()
sample_obj = cls.model(**kwargs) sample_obj = cls.model(**kwargs)
session = cls.get_db()
session.add(sample_obj) session.add(sample_obj)
await session.flush() await session.flush()
return sample_obj return sample_obj
@classmethod @classmethod
async def insert_many(cls, data_list, batch_size=100) -> None: async def insert_many(cls, data_list, batch_size=100,*, session: Optional[AsyncSession] = None)->None:
async with cls.get_db() as session: session = session or cls.get_db()
async with session:
for d in data_list: for d in data_list:
if not d.get("id", None): if not d.get("id", None):
d["id"] = get_uuid() d["id"] = get_uuid()
@ -168,59 +168,61 @@ class BaseService(Generic[T]):
session.add_all(data_list[i: i + batch_size]) session.add_all(data_list[i: i + batch_size])
@classmethod @classmethod
async def update_by_id(cls, pid, data) -> int: async def update_by_id(cls, pid, data,*, session: Optional[AsyncSession] = None)-> int:
session = session or cls.get_db()
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
session = cls.get_db()
result = await session.execute(update_stmt) result = await session.execute(update_stmt)
return result.rowcount return result.rowcount
@classmethod @classmethod
async def update_many_by_id(cls, data_list) -> None: async def update_many_by_id(cls, data_list,*, session: Optional[AsyncSession] = None)->None:
async with cls.get_db() as session: session = session or cls.get_db()
async with session:
for data in data_list: for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt) await session.execute(stmt)
@classmethod @classmethod
async def get_by_id(cls, pid) -> T: async def get_by_id(cls, pid,*, session: Optional[AsyncSession] = None)->T:
session = session or cls.get_db()
stmt = cls.model.select().where(cls.model.id == pid) stmt = cls.model.select().where(cls.model.id == pid)
session = cls.get_db()
return await session.scalar(stmt) return await session.scalar(stmt)
@classmethod @classmethod
async def get_by_ids(cls, pids, cols=None) -> List[T]: async def get_by_ids(cls, pids, cols=None,*, session: Optional[AsyncSession] = None)->List[T]:
session = session or cls.get_db()
if cols: if cols:
objs = cls.model.select(*cols) objs = cls.model.select(*cols)
else: else:
objs = cls.model.select() objs = cls.model.select()
stmt = objs.where(cls.model.id.in_(pids)) stmt = objs.where(cls.model.id.in_(pids))
session = cls.get_db()
result = await session.scalars(stmt) result = await session.scalars(stmt)
return list(result.all()) return list(result.all())
@classmethod @classmethod
async def delete_by_id(cls, pid) -> int: async def delete_by_id(cls, pid,*, session: Optional[AsyncSession] = None)-> int:
session = session or cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id == pid) del_stmt = cls.model.delete().where(cls.model.id == pid)
session = cls.get_db()
exec_result = await session.execute(del_stmt) exec_result = await session.execute(del_stmt)
return exec_result.rowcount return exec_result.rowcount
@classmethod @classmethod
async def delete_by_ids(cls, pids) -> int: async def delete_by_ids(cls, pids,*, session: Optional[AsyncSession] = None)-> int:
session = cls.get_db() session = session or cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt) result = await session.execute(del_stmt)
return result.rowcount return result.rowcount
@classmethod @classmethod
async def get_data_count(cls, query_params: dict = None) -> int: async def get_data_count(cls, query_params: dict = None,*, session: Optional[AsyncSession] = None) -> int:
session = session or cls.get_db()
if not query_params: if not query_params:
raise Exception("参数为空") raise Exception("参数为空")
stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)]) stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)])
# stmt = cls.get_query_stmt(query_params) # stmt = cls.get_query_stmt(query_params)
session = cls.get_db()
return await session.scalar(stmt) return await session.scalar(stmt)
@classmethod @classmethod
async def is_exist(cls, query_params: dict = None) -> bool: async def is_exist(cls, query_params: dict = None):
return await cls.get_data_count(query_params) > 0 return await cls.get_data_count(query_params) > 0

Loading…
Cancel
Save