diff --git a/common/global_enums.py b/common/global_enums.py index 6a5c351..cc2fc3a 100644 --- a/common/global_enums.py +++ b/common/global_enums.py @@ -5,6 +5,9 @@ from enum import IntEnum, StrEnum class IsDelete(IntEnum): NO_DELETE = 0 DELETE = 1 +class UserRoleEnum(StrEnum): + USER="user" + ADMIN="ADMIN" class LLMType(StrEnum): CHAT = 'chat' @@ -12,4 +15,4 @@ class LLMType(StrEnum): SPEECH2TEXT = 'speech2text' IMAGE2TEXT = 'image2text' RERANK = 'rerank' - TTS = 'tts' \ No newline at end of file + TTS = 'tts' diff --git a/entity/db_models.py b/entity/db_models.py index 920e426..aa7e06c 100644 --- a/entity/db_models.py +++ b/entity/db_models.py @@ -1,5 +1,6 @@ from sqlmodel import SQLModel +from common.global_enums import UserRoleEnum from entity import DbBaseModel, engine @@ -13,3 +14,4 @@ class User(DbBaseModel, table=True): __tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写 username: str password: str + user_role: UserRoleEnum diff --git a/middleware/db_session.py b/middleware/db_session.py index 50ac395..39286cf 100644 --- a/middleware/db_session.py +++ b/middleware/db_session.py @@ -17,4 +17,6 @@ class DbSessionMiddleWare(BaseHTTPMiddleware): finally: # 重置上下文变量 current_session.reset(token) + # 无论成功与否,都必须关闭会话 + await session.close() return response \ No newline at end of file diff --git a/service/base_service.py b/service/base_service.py index 2d4f3a0..ffc205c 100644 --- a/service/base_service.py +++ b/service/base_service.py @@ -1,4 +1,4 @@ -from typing import Union, Type, List, Any, TypeVar, Generic +from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional from fastapi_pagination import Params from fastapi_pagination.ext.sqlalchemy import paginate @@ -21,7 +21,6 @@ session.scalar: 直接明确获取一条数据,可以直接返回,无需额 """ T = TypeVar('T', bound=DbBaseModel) - class BaseService(Generic[T]): model: Type[T] # 子类必须指定模型 @@ -35,17 +34,17 @@ class BaseService(Generic[T]): return session @classmethod - def get_query_stmt(cls, query_params, sessions=None, *, fields: list = None): - if sessions is None: + def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None): + if stmt is None: if fields: - sessions = cls.model.select(*fields) + stmt = cls.model.select(*fields) else: - sessions = cls.model.select() + stmt = cls.model.select() for key, value in query_params.items(): if hasattr(cls.model, key) and value is not None: field = getattr(cls.model, key) - sessions = sessions.where(field == value) - return sessions + stmt = stmt.where(field == value) + return stmt @classmethod def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[ @@ -66,16 +65,17 @@ class BaseService(Generic[T]): pass @classmethod - async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]: + async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp[T]: if not isinstance(query_params, dict): query_params = query_params.model_dump() - query_params = {k: v for k, v in query_params.items() if v is not None} + query_params = {k: v for k, v in query_params.items() if v not in [None,""]} query_stmt = cls.get_query_stmt(query_params) return await cls.auto_page(query_stmt, query_params) @classmethod async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, - dto_model_class: Type[BaseModel] = None) -> BasePageResp[T]: + dto_model_class: Type[BaseModel] = None,*, session: Optional[AsyncSession] = None)->BasePageResp[T]: + session = session or cls.get_db() if not query_params: query_params = {} if not isinstance(query_params, dict): @@ -84,7 +84,6 @@ class BaseService(Generic[T]): page_size = query_params.get("page_size", 12) sort = query_params.get("sort", "desc") orderby = query_params.get("orderby", "created_time") - session = cls.get_db() data_count = None if data_count == 0: return BasePageResp(**{ @@ -116,10 +115,11 @@ class BaseService(Generic[T]): }) @classmethod - async def get_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[T]: + async def get_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None)->List[T]: + session = session or cls.get_db() if not isinstance(query_params, dict): query_params = query_params.model_dump() - query_params = {k: v for k, v in query_params.items() if v is not None} + query_params = {k: v for k, v in query_params.items() if v not in [None,""]} sort = query_params.get("sort", "desc") orderby = query_params.get("orderby", "created_time") query_stmt = cls.get_query_stmt(query_params) @@ -128,12 +128,12 @@ class BaseService(Generic[T]): query_stmt = query_stmt.order_by(field.desc()) else: query_stmt = query_stmt.order_by(field.asc()) - session = cls.get_db() exec_result = await session.execute(query_stmt) return list(exec_result.scalars().all()) @classmethod - async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[str]: + async def get_id_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None) -> List[str]: + session = session or cls.get_db() if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} @@ -145,21 +145,21 @@ class BaseService(Generic[T]): query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc()) else: query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc()) - session = cls.get_db() exec_result = await session.execute(query_stmt) return [item["id"] for item in exec_result.scalars().all()] @classmethod - async def save(cls, **kwargs) -> T: + async def save(cls,*, session: Optional[AsyncSession] = None, **kwargs)->T: + session = session or cls.get_db() sample_obj = cls.model(**kwargs) - session = cls.get_db() session.add(sample_obj) await session.flush() return sample_obj @classmethod - async def insert_many(cls, data_list, batch_size=100) -> None: - async with cls.get_db() as session: + async def insert_many(cls, data_list, batch_size=100,*, session: Optional[AsyncSession] = None)->None: + session = session or cls.get_db() + async with session: for d in data_list: if not d.get("id", None): d["id"] = get_uuid() @@ -168,59 +168,61 @@ class BaseService(Generic[T]): session.add_all(data_list[i: i + batch_size]) @classmethod - async def update_by_id(cls, pid, data) -> int: + async def update_by_id(cls, pid, data,*, session: Optional[AsyncSession] = None)-> int: + session = session or cls.get_db() update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) - session = cls.get_db() result = await session.execute(update_stmt) return result.rowcount @classmethod - async def update_many_by_id(cls, data_list) -> None: - async with cls.get_db() as session: + async def update_many_by_id(cls, data_list,*, session: Optional[AsyncSession] = None)->None: + session = session or cls.get_db() + async with session: for data in data_list: stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) await session.execute(stmt) @classmethod - async def get_by_id(cls, pid) -> T: + async def get_by_id(cls, pid,*, session: Optional[AsyncSession] = None)->T: + session = session or cls.get_db() stmt = cls.model.select().where(cls.model.id == pid) - session = cls.get_db() return await session.scalar(stmt) + @classmethod - async def get_by_ids(cls, pids, cols=None) -> List[T]: + async def get_by_ids(cls, pids, cols=None,*, session: Optional[AsyncSession] = None)->List[T]: + session = session or cls.get_db() if cols: objs = cls.model.select(*cols) else: objs = cls.model.select() stmt = objs.where(cls.model.id.in_(pids)) - session = cls.get_db() result = await session.scalars(stmt) return list(result.all()) @classmethod - async def delete_by_id(cls, pid) -> int: + async def delete_by_id(cls, pid,*, session: Optional[AsyncSession] = None)-> int: + session = session or cls.get_db() del_stmt = cls.model.delete().where(cls.model.id == pid) - session = cls.get_db() exec_result = await session.execute(del_stmt) return exec_result.rowcount @classmethod - async def delete_by_ids(cls, pids) -> int: - session = cls.get_db() + async def delete_by_ids(cls, pids,*, session: Optional[AsyncSession] = None)-> int: + session = session or cls.get_db() del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) result = await session.execute(del_stmt) return result.rowcount @classmethod - async def get_data_count(cls, query_params: dict = None) -> int: + async def get_data_count(cls, query_params: dict = None,*, session: Optional[AsyncSession] = None) -> int: + session = session or cls.get_db() if not query_params: raise Exception("参数为空") stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)]) # stmt = cls.get_query_stmt(query_params) - session = cls.get_db() return await session.scalar(stmt) @classmethod - async def is_exist(cls, query_params: dict = None) -> bool: + async def is_exist(cls, query_params: dict = None): return await cls.get_data_count(query_params) > 0