You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
from contextvars import ContextVar
|
|
from contextvars import ContextVar
|
|
from typing import Optional, TypeVar, Generic, Type
|
|
|
|
from fastapi import Request
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from entity import AsyncSessionLocal
|
|
from entity.user import User
|
|
|
|
# 1. Context variable holding the AsyncSession of the current request.
#    It is None whenever no session has been set in the current context.
current_session: ContextVar[Optional[AsyncSession]] = ContextVar(
    "current_session",
    default=None,
)


# 3. Middleware: manage the request lifecycle and its database session

async def db_session_middleware(request: Request, call_next):
    """Run a request with its own session exposed through ``current_session``.

    A fresh session from ``AsyncSessionLocal`` is stored in the
    ``current_session`` context variable while ``call_next`` runs. Afterwards:

    * a response with status code < 400 commits the session;
    * a response with status code >= 400 rolls the session back;
    * an exception raised by ``call_next`` or by the commit rolls the session
      back and is re-raised.

    The context variable is always restored to its previous value. The
    ``async with`` block releases the session on exit (assumes
    ``AsyncSessionLocal`` is an SQLAlchemy async session factory — confirm).

    Returns:
        The response produced by ``call_next``.
    """
    async with AsyncSessionLocal() as session:
        # Make the session visible to service code executed for this request.
        token = current_session.set(session)
        try:
            response = await call_next(request)
            # Errors the framework converts into responses (e.g. a handled
            # HTTPException) do not raise here, so inspect the status code:
            # work already flushed by a failed request must not be committed.
            if response.status_code < 400:
                await session.commit()
            else:
                await session.rollback()
        except Exception:
            await session.rollback()
            raise
        finally:
            # Restore the previous value so this session does not leak into
            # other code sharing the context.
            current_session.reset(token)
    return response


# 4. Service base class

# Type variable for the model class a concrete service manages.
T = TypeVar('T')


class BaseService(Generic[T]):
    """Generic base for services that operate on a single model class.

    Every method is a classmethod that works against the session stored in
    the ``current_session`` context variable, so services can only be used
    while that variable holds a session. Subclasses must assign ``model``.
    """

    # Model class managed by this service; subclasses must override it.
    model: Optional[Type[T]] = None

    @classmethod
    def _require_model(cls) -> Type[T]:
        """Return ``cls.model``, failing with a clear message if it is unset.

        Raises:
            TypeError: if the subclass did not assign ``model``.
        """
        model = cls.model
        if model is None:
            # Without this guard the failure surfaces as the opaque
            # "'NoneType' object is not callable" (create) or as an error
            # raised from inside the session's get() (get).
            raise TypeError(
                f"{cls.__name__}.model is not set; "
                "subclasses of BaseService must assign a model class."
            )
        return model

    @classmethod
    def get_db(cls) -> AsyncSession:
        """Return the session bound to the current context.

        Raises:
            RuntimeError: if ``current_session`` holds no session.
        """
        session = current_session.get()
        if session is None:
            raise RuntimeError("No database session in context. "
                               "Make sure to use this service within a request context.")
        return session

    @classmethod
    async def create(cls, **kwargs) -> T:
        """Build a ``model`` instance from ``kwargs``, add it and flush.

        The flush writes pending changes within the current transaction but
        does not commit; committing is left to whoever owns the session.

        Returns:
            The newly created, flushed instance.

        Raises:
            TypeError: if ``model`` is not set.
            RuntimeError: if no session is available.
        """
        model = cls._require_model()
        obj = model(**kwargs)
        db = cls.get_db()
        db.add(obj)
        await db.flush()
        return obj

    @classmethod
    async def get(cls, id: int) -> Optional[T]:
        """Return the ``model`` instance with primary key ``id``, or ``None``.

        Raises:
            RuntimeError: if no session is available.
            TypeError: if ``model`` is not set.
        """
        db = cls.get_db()
        return await db.get(cls._require_model(), id)


# 5. Concrete services

class UserService(BaseService[User]):
    """Service for ``User``; ``create``/``get`` come from ``BaseService``."""

    # The model class this service operates on.
    model = User