feat: 使用with_db_session装饰器工厂控制数据库会话,不再通过FastAPI中间件

main
chenzhirong 2 months ago
parent ec6ef80152
commit 405e7c41a9

@ -1,7 +1,7 @@
import functools
import inspect import inspect
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import wraps from typing import Any, ParamSpec, TypeVar, Callable
from typing import Any
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_ from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
@ -9,7 +9,10 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, Asyn
from common.constant import Constant from common.constant import Constant
from common.global_enums import IsDelete from common.global_enums import IsDelete
from config import settings from config import settings
from entity.base_entity import DbBaseModel from entity.base_entity import DbBaseModel
P = ParamSpec('P')
T = TypeVar('T')
engine = create_async_engine( engine = create_async_engine(
settings.database_url, settings.database_url,
@ -149,14 +152,33 @@ async def get_db_session():
yield session yield session
def with_db_session(func):
@wraps(func) def with_db_session(session_param_name: str = "session"):
async def wrapper(*args, **kwargs): """
一个装饰器用于为异步函数自动注入数据库会话
Args:
session_param_name: 被装饰函数中用于接收会话的参数名默认为 'session'
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
# 确保只装饰异步函数
if not inspect.iscoroutinefunction(func):
raise TypeError("`with_db_session` can only be used on async functions.")
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# 如果调用时已经手动传了 session就直接用
if session_param_name in kwargs:
return await func(*args, **kwargs)
# 否则,创建一个新 session 并注入
async with get_db_session() as session: async with get_db_session() as session:
result = await func(db_session=session, *args, **kwargs) kwargs[session_param_name] = session
return result return await func(*args, **kwargs)
return wrapper return wrapper
return decorator
# 关闭引擎 # 关闭引擎

@ -14,9 +14,6 @@ from config import settings, show_configs
stop_event = threading.Event() stop_event = threading.Event()
RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0"))
def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")
stop_event.set() stop_event.set()
@ -30,10 +27,6 @@ if __name__ == '__main__':
f'project base: {file_utils.get_project_base_directory()}' f'project base: {file_utils.get_project_base_directory()}'
) )
show_configs() show_configs()
# import argparse
# parser = argparse.ArgumentParser()
#
# args = parser.parse_args()
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
try: try:

@ -18,7 +18,7 @@ def add_middleware(app: FastAPI):
max_age=2592000 max_age=2592000
) )
app.add_middleware(SessionMiddleware, secret_key=secrets.token_hex(32)) app.add_middleware(SessionMiddleware, secret_key=secrets.token_hex(32))
app.add_middleware(DbSessionMiddleWare) # app.add_middleware(DbSessionMiddleWare) #不再需要

@ -1,13 +1,13 @@
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional from typing import Union, Type, List, Any, TypeVar, Generic, Optional
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from core.global_context import current_session from entity import with_db_session
from entity import DbBaseModel
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
from utils import get_uuid from utils import get_uuid
@ -19,20 +19,12 @@ session.scalars: 只适合单模型查询(不适合指定列或连表查询)
session.scalar: 直接明确获取一条数据可以直接返回无需额外处理 session.scalar: 直接明确获取一条数据可以直接返回无需额外处理
""" """
T = TypeVar('T', bound=DbBaseModel) T = TypeVar('T', bound=SQLModel)
class BaseService(Generic[T]): class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型 model: Type[T] # 子类必须指定模型
@classmethod
def get_db(cls) -> AsyncSession:
"""获取当前请求的会话"""
session = current_session.get()
if session is None:
raise RuntimeError("No database session in context. "
"Make sure to use this service within a request context.")
return session
@classmethod @classmethod
def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None): def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
if stmt is None: if stmt is None:
@ -41,8 +33,15 @@ class BaseService(Generic[T]):
else: else:
stmt = cls.model.select() stmt = cls.model.select()
for key, value in query_params.items(): for key, value in query_params.items():
if hasattr(cls.model, key) and value is not None: if value is None:
continue
if isinstance(key, str) and hasattr(cls.model, key): # 第一步:先确定 key 的类型
# 第二步:根据类型,用对应的方式处理
field = getattr(cls.model, key) field = getattr(cls.model, key)
elif hasattr(key, 'model') and key.model is cls.model:
field = key
else:
continue
stmt = stmt.where(field == value) stmt = stmt.where(field == value)
return stmt return stmt
@ -55,7 +54,7 @@ class BaseService(Generic[T]):
for entity in entity_data: for entity in entity_data:
temp = entity temp = entity
if not isinstance(entity, dict): if not isinstance(entity, dict):
temp = entity.to_dict() temp = entity.model_dump()
dto_list.append(dto(**temp)) dto_list.append(dto(**temp))
return dto_list return dto_list
@ -65,17 +64,18 @@ class BaseService(Generic[T]):
pass pass
@classmethod @classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp[T]: async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]) -> BasePageResp[T]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v not in [None,""]} query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
query_stmt = cls.get_query_stmt(query_params) query_stmt = cls.get_query_stmt(query_params)
return await cls.auto_page(query_stmt, query_params) return await cls.auto_page(query_stmt, query_params)
@classmethod @classmethod
@with_db_session()
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None,*, session: Optional[AsyncSession] = None)->BasePageResp[T]: dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> \
session = session or cls.get_db() BasePageResp[T]:
if not query_params: if not query_params:
query_params = {} query_params = {}
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
@ -115,11 +115,12 @@ class BaseService(Generic[T]):
}) })
@classmethod @classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None)->List[T]: @with_db_session()
session = session or cls.get_db() async def get_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> List[
T]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v not in [None,""]} query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
sort = query_params.get("sort", "desc") sort = query_params.get("sort", "desc")
orderby = query_params.get("orderby", "created_time") orderby = query_params.get("orderby", "created_time")
query_stmt = cls.get_query_stmt(query_params) query_stmt = cls.get_query_stmt(query_params)
@ -128,12 +129,24 @@ class BaseService(Generic[T]):
query_stmt = query_stmt.order_by(field.desc()) query_stmt = query_stmt.order_by(field.desc())
else: else:
query_stmt = query_stmt.order_by(field.asc()) query_stmt = query_stmt.order_by(field.asc())
if query_params.get("limit", None) is not None:
query_stmt = query_stmt.limit(query_params.get("limit"))
exec_result = await session.execute(query_stmt) exec_result = await session.execute(query_stmt)
return list(exec_result.scalars().all()) return list(exec_result.scalars().all())
@classmethod @classmethod
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None) -> List[str]: @with_db_session()
session = session or cls.get_db() async def get_list_json(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
List[
T]:
resp_list = await cls.get_list(query_params, session=session)
return [i.model_dump() for i in resp_list]
@classmethod
@with_db_session()
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
List[str]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
@ -149,16 +162,24 @@ class BaseService(Generic[T]):
return [item["id"] for item in exec_result.scalars().all()] return [item["id"] for item in exec_result.scalars().all()]
@classmethod @classmethod
async def save(cls,*, session: Optional[AsyncSession] = None, **kwargs)->T: @with_db_session()
session = session or cls.get_db() async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T:
sample_obj = cls.model(**kwargs) sample_obj = cls.model(**kwargs)
session.add(sample_obj) session.add(sample_obj)
await session.flush() await session.flush()
return sample_obj return sample_obj
@classmethod @classmethod
async def insert_many(cls, data_list, batch_size=100,*, session: Optional[AsyncSession] = None)->None: @with_db_session()
session = session or cls.get_db() async def save_entity(cls, db_model: SQLModel, *, session: Optional[AsyncSession] = None) -> T:
session.add(db_model)
await session.flush()
return db_model
@classmethod
@with_db_session()
async def insert_many(cls, data_list, batch_size=100, *, session: Optional[AsyncSession] = None) -> None:
async with session: async with session:
for d in data_list: for d in data_list:
if not d.get("id", None): if not d.get("id", None):
@ -168,30 +189,41 @@ class BaseService(Generic[T]):
session.add_all(data_list[i: i + batch_size]) session.add_all(data_list[i: i + batch_size])
@classmethod @classmethod
async def update_by_id(cls, pid, data,*, session: Optional[AsyncSession] = None)-> int: @with_db_session()
session = session or cls.get_db() async def update_by_id(cls, pid, data, *, session: Optional[AsyncSession] = None) -> int:
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
result = await session.execute(update_stmt) result = await session.execute(update_stmt)
return result.rowcount return result.rowcount
@classmethod @classmethod
async def update_many_by_id(cls, data_list,*, session: Optional[AsyncSession] = None)->None: @with_db_session()
session = session or cls.get_db() async def update_many_by_id(cls, data_list, *, session: Optional[AsyncSession] = None) -> None:
async with session: async with session:
for data in data_list: for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt) await session.execute(stmt)
@classmethod @classmethod
async def get_by_id(cls, pid,*, session: Optional[AsyncSession] = None)->T: @with_db_session()
session = session or cls.get_db() async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T:
stmt = cls.model.select().where(cls.model.id == pid) stmt = cls.model.select().where(cls.model.id == pid)
return await session.scalar(stmt) return await session.scalar(stmt)
@classmethod
@with_db_session()
async def get_one(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> T:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
query_stmt = cls.get_query_stmt(query_params)
return await session.scalar(query_stmt)
@classmethod @classmethod
async def get_by_ids(cls, pids, cols=None,*, session: Optional[AsyncSession] = None)->List[T]: @with_db_session()
session = session or cls.get_db() async def get_by_ids(cls, pids, cols=None, *, session: Optional[AsyncSession] = None) -> List[T]:
if cols: if cols:
objs = cls.model.select(*cols) objs = cls.model.select(*cols)
else: else:
@ -201,22 +233,35 @@ class BaseService(Generic[T]):
return list(result.all()) return list(result.all())
@classmethod @classmethod
async def delete_by_id(cls, pid,*, session: Optional[AsyncSession] = None)-> int: @with_db_session()
session = session or cls.get_db() async def delete(cls, delete_params: dict, *, session: Optional[AsyncSession] = None) -> int:
del_stmt = cls.model.delete()
for k, v in delete_params.items():
del_stmt = del_stmt.where(getattr(cls.model, k) == v)
exec_result = await session.execute(del_stmt)
return exec_result.rowcount
@classmethod
@with_db_session()
async def delete_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> int:
del_stmt = cls.model.delete().where(cls.model.id == pid) del_stmt = cls.model.delete().where(cls.model.id == pid)
exec_result = await session.execute(del_stmt) exec_result = await session.execute(del_stmt)
return exec_result.rowcount return exec_result.rowcount
@classmethod @classmethod
async def delete_by_ids(cls, pids,*, session: Optional[AsyncSession] = None)-> int: @with_db_session()
session = session or cls.get_db() async def delete_by_ids(cls, pids, *, session: Optional[AsyncSession] = None) -> int:
del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt) result = await session.execute(del_stmt)
return result.rowcount return result.rowcount
@classmethod @classmethod
async def get_data_count(cls, query_params: dict = None,*, session: Optional[AsyncSession] = None) -> int: @with_db_session()
session = session or cls.get_db() async def get_data_count(cls, query_params: dict = None, *, session: Optional[AsyncSession] = None) -> int:
if not query_params: if not query_params:
raise Exception("参数为空") raise Exception("参数为空")
stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)]) stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)])

Loading…
Cancel
Save