feat: 使用with_db_session装饰器工厂控制数据库会话,不再通过FastAPI中间件

main
chenzhirong 2 months ago
parent ec6ef80152
commit 405e7c41a9

@ -1,7 +1,7 @@
import functools
import inspect
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any
from typing import Any, ParamSpec, TypeVar, Callable
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
@ -9,7 +9,10 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, Asyn
from common.constant import Constant
from common.global_enums import IsDelete
from config import settings
from entity.base_entity import DbBaseModel
P = ParamSpec('P')
T = TypeVar('T')
engine = create_async_engine(
settings.database_url,
@ -149,14 +152,33 @@ async def get_db_session():
yield session
def with_db_session(func):
@wraps(func)
async def wrapper(*args, **kwargs):
def with_db_session(session_param_name: str = "session"):
"""
一个装饰器用于为异步函数自动注入数据库会话
Args:
session_param_name: 被装饰函数中用于接收会话的参数名默认为 'session'
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
# 确保只装饰异步函数
if not inspect.iscoroutinefunction(func):
raise TypeError("`with_db_session` can only be used on async functions.")
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# 如果调用时已经手动传了 session就直接用
if session_param_name in kwargs:
return await func(*args, **kwargs)
# 否则,创建一个新 session 并注入
async with get_db_session() as session:
result = await func(db_session=session, *args, **kwargs)
return result
kwargs[session_param_name] = session
return await func(*args, **kwargs)
return wrapper
return decorator
# 关闭引擎

@ -14,9 +14,6 @@ from config import settings, show_configs
stop_event = threading.Event()
RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0"))
def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...")
stop_event.set()
@ -30,10 +27,6 @@ if __name__ == '__main__':
f'project base: {file_utils.get_project_base_directory()}'
)
show_configs()
# import argparse
# parser = argparse.ArgumentParser()
#
# args = parser.parse_args()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:

@ -18,7 +18,7 @@ def add_middleware(app: FastAPI):
max_age=2592000
)
app.add_middleware(SessionMiddleware, secret_key=secrets.token_hex(32))
app.add_middleware(DbSessionMiddleWare)
# app.add_middleware(DbSessionMiddleWare) #不再需要

@ -1,13 +1,13 @@
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
from typing import Union, Type, List, Any, TypeVar, Generic, Optional
from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from core.global_context import current_session
from entity import DbBaseModel
from entity import with_db_session
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
from utils import get_uuid
@ -19,20 +19,12 @@ session.scalars: 只适合单模型查询(不适合指定列或连表查询)
session.scalar: 直接明确获取一条数据可以直接返回无需额外处理
"""
T = TypeVar('T', bound=DbBaseModel)
T = TypeVar('T', bound=SQLModel)
class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型
@classmethod
def get_db(cls) -> AsyncSession:
"""获取当前请求的会话"""
session = current_session.get()
if session is None:
raise RuntimeError("No database session in context. "
"Make sure to use this service within a request context.")
return session
@classmethod
def get_query_stmt(cls, query_params, stmt=None, *, fields: list = None):
if stmt is None:
@ -41,8 +33,15 @@ class BaseService(Generic[T]):
else:
stmt = cls.model.select()
for key, value in query_params.items():
if hasattr(cls.model, key) and value is not None:
if value is None:
continue
if isinstance(key, str) and hasattr(cls.model, key): # 第一步:先确定 key 的类型
# 第二步:根据类型,用对应的方式处理
field = getattr(cls.model, key)
elif hasattr(key, 'model') and key.model is cls.model:
field = key
else:
continue
stmt = stmt.where(field == value)
return stmt
@ -55,7 +54,7 @@ class BaseService(Generic[T]):
for entity in entity_data:
temp = entity
if not isinstance(entity, dict):
temp = entity.to_dict()
temp = entity.model_dump()
dto_list.append(dto(**temp))
return dto_list
@ -73,9 +72,10 @@ class BaseService(Generic[T]):
return await cls.auto_page(query_stmt, query_params)
@classmethod
@with_db_session()
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None,*, session: Optional[AsyncSession] = None)->BasePageResp[T]:
session = session or cls.get_db()
dto_model_class: Type[BaseModel] = None, *, session: Optional[AsyncSession]) -> \
BasePageResp[T]:
if not query_params:
query_params = {}
if not isinstance(query_params, dict):
@ -115,8 +115,9 @@ class BaseService(Generic[T]):
})
@classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None)->List[T]:
session = session or cls.get_db()
@with_db_session()
async def get_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> List[
T]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
@ -128,12 +129,24 @@ class BaseService(Generic[T]):
query_stmt = query_stmt.order_by(field.desc())
else:
query_stmt = query_stmt.order_by(field.asc())
if query_params.get("limit", None) is not None:
query_stmt = query_stmt.limit(query_params.get("limit"))
exec_result = await session.execute(query_stmt)
return list(exec_result.scalars().all())
@classmethod
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq],*, session: Optional[AsyncSession] = None) -> List[str]:
session = session or cls.get_db()
@with_db_session()
async def get_list_json(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
List[
T]:
resp_list = await cls.get_list(query_params, session=session)
return [i.model_dump() for i in resp_list]
@classmethod
@with_db_session()
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> \
List[str]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
@ -149,16 +162,24 @@ class BaseService(Generic[T]):
return [item["id"] for item in exec_result.scalars().all()]
@classmethod
@with_db_session()
async def save(cls, *, session: Optional[AsyncSession] = None, **kwargs) -> T:
session = session or cls.get_db()
sample_obj = cls.model(**kwargs)
session.add(sample_obj)
await session.flush()
return sample_obj
@classmethod
@with_db_session()
async def save_entity(cls, db_model: SQLModel, *, session: Optional[AsyncSession] = None) -> T:
session.add(db_model)
await session.flush()
return db_model
@classmethod
@with_db_session()
async def insert_many(cls, data_list, batch_size=100, *, session: Optional[AsyncSession] = None) -> None:
session = session or cls.get_db()
async with session:
for d in data_list:
if not d.get("id", None):
@ -168,30 +189,41 @@ class BaseService(Generic[T]):
session.add_all(data_list[i: i + batch_size])
@classmethod
@with_db_session()
async def update_by_id(cls, pid, data, *, session: Optional[AsyncSession] = None) -> int:
session = session or cls.get_db()
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
result = await session.execute(update_stmt)
return result.rowcount
@classmethod
@with_db_session()
async def update_many_by_id(cls, data_list, *, session: Optional[AsyncSession] = None) -> None:
session = session or cls.get_db()
async with session:
for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt)
@classmethod
@with_db_session()
async def get_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> T:
session = session or cls.get_db()
stmt = cls.model.select().where(cls.model.id == pid)
return await session.scalar(stmt)
@classmethod
@with_db_session()
async def get_one(cls, query_params: Union[dict, BaseQueryReq], *, session: Optional[AsyncSession] = None) -> T:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v not in [None, ""]}
query_stmt = cls.get_query_stmt(query_params)
return await session.scalar(query_stmt)
@classmethod
@with_db_session()
async def get_by_ids(cls, pids, cols=None, *, session: Optional[AsyncSession] = None) -> List[T]:
session = session or cls.get_db()
if cols:
objs = cls.model.select(*cols)
else:
@ -201,22 +233,35 @@ class BaseService(Generic[T]):
return list(result.all())
@classmethod
@with_db_session()
async def delete(cls, delete_params: dict, *, session: Optional[AsyncSession] = None) -> int:
del_stmt = cls.model.delete()
for k, v in delete_params.items():
del_stmt = del_stmt.where(getattr(cls.model, k) == v)
exec_result = await session.execute(del_stmt)
return exec_result.rowcount
@classmethod
@with_db_session()
async def delete_by_id(cls, pid, *, session: Optional[AsyncSession] = None) -> int:
session = session or cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id == pid)
exec_result = await session.execute(del_stmt)
return exec_result.rowcount
@classmethod
@with_db_session()
async def delete_by_ids(cls, pids, *, session: Optional[AsyncSession] = None) -> int:
session = session or cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt)
return result.rowcount
@classmethod
@with_db_session()
async def get_data_count(cls, query_params: dict = None, *, session: Optional[AsyncSession] = None) -> int:
session = session or cls.get_db()
if not query_params:
raise Exception("参数为空")
stmt = cls.get_query_stmt(query_params, fields=[func.count(cls.model.id)])

Loading…
Cancel
Save