From 405e7c41a95c17c40bfbb6680a43dfc90997b221 Mon Sep 17 00:00:00 2001 From: chenzhirong <826531489@qq.com> Date: Thu, 27 Nov 2025 17:55:32 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BD=BF=E7=94=A8with=5Fdb=5Fsession?= =?UTF-8?q?=E8=A3=85=E9=A5=B0=E5=99=A8=E5=B7=A5=E5=8E=82=E6=8E=A7=E5=88=B6?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E4=BC=9A=E8=AF=9D=EF=BC=8C=E4=B8=8D?= =?UTF-8?q?=E5=86=8D=E9=80=9A=E8=BF=87FastAPI=E4=B8=AD=E9=97=B4=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- entity/__init__.py | 40 +++++++++--- main.py | 7 --- middleware/__init__.py | 2 +- service/base_service.py | 131 +++++++++++++++++++++++++++------------- 4 files changed, 120 insertions(+), 60 deletions(-) diff --git a/entity/__init__.py b/entity/__init__.py index a7a48ae..098388f 100644 --- a/entity/__init__.py +++ b/entity/__init__.py @@ -1,7 +1,7 @@ +import functools import inspect from contextlib import asynccontextmanager -from functools import wraps -from typing import Any +from typing import Any, ParamSpec, TypeVar, Callable from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession @@ -9,7 +9,10 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, Asyn from common.constant import Constant from common.global_enums import IsDelete from config import settings + from entity.base_entity import DbBaseModel +P = ParamSpec('P') +T = TypeVar('T') engine = create_async_engine( settings.database_url, @@ -149,14 +152,33 @@ async def get_db_session(): yield session -def with_db_session(func): - @wraps(func) - async def wrapper(*args, **kwargs): - async with get_db_session() as session: - result = await func(db_session=session, *args, **kwargs) - return result - return wrapper +def with_db_session(session_param_name: str = "session"): + """ + 一个装饰器,用于为异步函数自动注入数据库会话。 + + Args: + session_param_name: 被装饰函数中,用于接收会话的参数名,默认为 'session'。 + """ + def decorator(func: Callable[P, T]) -> Callable[P, T]: + # 确保只装饰异步函数 + if not inspect.iscoroutinefunction(func): + raise TypeError("`with_db_session` can only be used on async functions.") + + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + # 如果调用时已经手动传了 session,就直接用 + if session_param_name in kwargs: + return await func(*args, **kwargs) + + # 否则,创建一个新 session 并注入 + async with get_db_session() as session: + kwargs[session_param_name] = session + return await func(*args, **kwargs) + + return wrapper + return decorator + # 关闭引擎 diff --git a/main.py b/main.py index c6678d9..b9f12ef 100644 --- a/main.py +++ b/main.py @@ -14,9 +14,6 @@ from config import settings, show_configs stop_event = threading.Event() -RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0")) - - def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") stop_event.set() @@ -30,10 +27,6 @@ if __name__ == '__main__': f'project base: {file_utils.get_project_base_directory()}' ) show_configs() - # import argparse - # parser = argparse.ArgumentParser() - # - # args = parser.parse_args() signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: diff --git a/middleware/__init__.py b/middleware/__init__.py index ba50d88..e1d2071 100644 --- a/middleware/__init__.py +++ b/middleware/__init__.py @@ -18,7 +18,7 @@ def add_middleware(app: FastAPI): max_age=2592000 ) app.add_middleware(SessionMiddleware, secret_key=secrets.token_hex(32)) - app.add_middleware(DbSessionMiddleWare) + # app.add_middleware(DbSessionMiddleWare) #不再需要 diff --git a/service/base_service.py b/service/base_service.py index ffc205c..7b4b6cd 100644 --- a/service/base_service.py +++ b/service/base_service.py @@ -1,13 +1,13 @@ -from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional +from typing import Union, Type, List, Any, TypeVar, Generic, Optional from fastapi_pagination import Params from fastapi_pagination.ext.sqlalchemy import paginate from pydantic import BaseModel from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import SQLModel -from core.global_context import current_session -from entity import DbBaseModel +from entity import with_db_session from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq from utils import get_uuid @@ -19,20 +19,12 @@ session.scalars: 只适合单模型查询(不适合指定列或连表查询) session.scalar: 直接明确获取一条数据,可以直接返回,无需额外处理 """ -T = TypeVar('T', bound=DbBaseModel) +T = TypeVar('T', bound=SQLModel) + class BaseService(Generic[T]): model: Type[T] # 子类必须指定模型 - @classmethod - def get_db(cls) -> AsyncSession: - """获取当前请求的会话""" - session = current_session.get() - if session is None: - raise RuntimeError("No database session in context. " - "Make sure to use this service within a request context.") - return session - @classmethod def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None): if stmt is None: @@ -41,9 +33,16 @@ class BaseService(Generic[T]): else: stmt = cls.model.select() for key, value in query_params.items(): - if hasattr(cls.model, key) and value is not None: + if value is None: + continue + if isinstance(key, str) and hasattr(cls.model, key): # 第一步:先确定 key 的类型 + # 第二步:根据类型,用对应的方式处理 field = getattr(cls.model, key) - stmt = stmt.where(field == value) + elif hasattr(key, 'model') and key.model is cls.model: + field = key + else: + continue + stmt = stmt.where(field == value) return stmt @classmethod @@ -55,7 +54,7 @@ class BaseService(Generic[T]): for entity in entity_data: temp = entity if not isinstance(entity, dict): - temp = entity.to_dict() + temp = entity.model_dump() dto_list.append(dto(**temp)) return dto_list @@ -65,17 +64,18 @@ class BaseService(Generic[T]): pass @classmethod - async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp[T]: + async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]: if not isinstance(query_params, dict): query_params = query_params.model_dump() - query_params = {k: v for k, v in query_params.items() if v not in [None,""]} + query_params = {k: v for k, v in query_params.items() if v not in [None, ""]} query_stmt = cls.get_query_stmt(query_params) return await cls.auto_page(query_stmt, query_params) @classmethod + @with_db_session() async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, - dto_model_class: Type[BaseModel] = None,*, session: Optional[AsyncSession] = None)->BasePageResp[T]: - session = session or cls.get_db() + dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> \ + BasePageResp[T]: if not query_params: query_params = {} if not isinstance(query_params, dict): @@ -115,11 +115,12 @@ class BaseService(Generic[T]): }) @classmethod - async def get_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None)->List[T]: - session = session or cls.get_db() + @with_db_session() + async def get_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> List[ + T]: if not isinstance(query_params, dict): query_params = query_params.model_dump() - query_params = {k: v for k, v in query_params.items() if v not in [None,""]} + query_params = {k: v for k, v in query_params.items() if v not in [None, ""]} sort = query_params.get("sort", "desc") orderby = query_params.get("orderby", "created_time") query_stmt = cls.get_query_stmt(query_params) @@ -128,12 +129,24 @@ class BaseService(Generic[T]): query_stmt = query_stmt.order_by(field.desc()) else: query_stmt = query_stmt.order_by(field.asc()) + if query_params.get("limit", None) is not None: + query_stmt = query_stmt.limit(query_params.get("limit")) exec_result = await session.execute(query_stmt) return list(exec_result.scalars().all()) @classmethod - async def get_id_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None) -> List[str]: - session = session or cls.get_db() + @with_db_session() + async def get_list_json(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \ + List[ + T]: + resp_list = await cls.get_list(query_params, session=session) + + return [i.model_dump() for i in resp_list] + + @classmethod + @with_db_session() + async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \ + List[str]: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} @@ -149,16 +162,24 @@ class BaseService(Generic[T]): return [item["id"] for item in exec_result.scalars().all()] @classmethod - async def save(cls,*, session: Optional[AsyncSession] = None, **kwargs)->T: - session = session or cls.get_db() + @with_db_session() + async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T: + sample_obj = cls.model(**kwargs) session.add(sample_obj) await session.flush() return sample_obj @classmethod - async def insert_many(cls, data_list, batch_size=100,*, session: Optional[AsyncSession] = None)->None: - session = session or cls.get_db() + @with_db_session() + async def save_entity(cls, db_model: SQLModel, *, session: Optional[AsyncSession] = None) -> T: + session.add(db_model) + await session.flush() + return db_model + + @classmethod + @with_db_session() + async def insert_many(cls, data_list, batch_size=100, *, session: Optional[AsyncSession] = None) -> None: async with session: for d in data_list: if not d.get("id", None): @@ -168,30 +189,41 @@ class BaseService(Generic[T]): session.add_all(data_list[i: i + batch_size]) @classmethod - async def update_by_id(cls, pid, data,*, session: Optional[AsyncSession] = None)-> int: - session = session or cls.get_db() + @with_db_session() + async def update_by_id(cls, pid, data, *, session: Optional[AsyncSession] = None) -> int: update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) result = await session.execute(update_stmt) return result.rowcount @classmethod - async def update_many_by_id(cls, data_list,*, session: Optional[AsyncSession] = None)->None: - session = session or cls.get_db() + @with_db_session() + async def update_many_by_id(cls, data_list, *, session: Optional[AsyncSession] = None) -> None: async with session: for data in data_list: stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) await session.execute(stmt) @classmethod - async def get_by_id(cls, pid,*, session: Optional[AsyncSession] = None)->T: - session = session or cls.get_db() + @with_db_session() + async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T: + stmt = cls.model.select().where(cls.model.id == pid) return await session.scalar(stmt) + @classmethod + @with_db_session() + async def get_one(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> T: + + if not isinstance(query_params, dict): + query_params = query_params.model_dump() + query_params = {k: v for k, v in query_params.items() if v not in [None, ""]} + query_stmt = cls.get_query_stmt(query_params) + return await session.scalar(query_stmt) @classmethod - async def get_by_ids(cls, pids, cols=None,*, session: Optional[AsyncSession] = None)->List[T]: - session = session or cls.get_db() + @with_db_session() + async def get_by_ids(cls, pids, cols=None, *, session: Optional[AsyncSession] = None) -> List[T]: + if cols: objs = cls.model.select(*cols) else: @@ -201,22 +233,35 @@ class BaseService(Generic[T]): return list(result.all()) @classmethod - async def delete_by_id(cls, pid,*, session: Optional[AsyncSession] = None)-> int: - session = session or cls.get_db() + @with_db_session() + async def delete(cls, delete_params: dict, *, session: Optional[AsyncSession] = None) -> int: + + del_stmt = cls.model.delete() + for k, v in delete_params.items(): + del_stmt = del_stmt.where(getattr(cls.model, k) == v) + exec_result = await session.execute(del_stmt) + return exec_result.rowcount + + @classmethod + @with_db_session() + async def delete_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> int: + del_stmt = cls.model.delete().where(cls.model.id == pid) exec_result = await session.execute(del_stmt) return exec_result.rowcount @classmethod - async def delete_by_ids(cls, pids,*, session: Optional[AsyncSession] = None)-> int: - session = session or cls.get_db() + @with_db_session() + async def delete_by_ids(cls, pids, *, session: Optional[AsyncSession] = None) -> int: + del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) result = await session.execute(del_stmt) return result.rowcount @classmethod - async def get_data_count(cls, query_params: dict = None,*, session: Optional[AsyncSession] = None) -> int: - session = session or cls.get_db() + @with_db_session() + async def get_data_count(cls, query_params: dict = None, *, session: Optional[AsyncSession] = None) -> int: + if not query_params: raise Exception("参数为空") stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)])