diff --git a/entity/__init__.py b/entity/__init__.py index 2704de6..dbe57d5 100644 --- a/entity/__init__.py +++ b/entity/__init__.py @@ -13,14 +13,39 @@ from entity.base_entity import DbBaseModel engine = create_async_engine( settings.database_url, - echo=True, # 打印SQL日志(生产环境建议关闭) + echo=False, # 打印SQL日志(生产环境建议关闭) pool_size=10, # 连接池大小 max_overflow=20, # 最大溢出连接数 pool_recycle=3600, # 连接回收时间(秒),解决MySQL超时断开问题【4†source】【5†source】 ) + # 创建异步会话工厂 class EnhanceAsyncSession(AsyncSession): + async def scalar(self, statement: Executable, + params=None, + *, + execution_options=None, + bind_arguments=None, + **kw: Any, ): + + sig = inspect.signature(super().scalar) + if execution_options is None: + default_execution_options = sig.parameters['execution_options'].default + execution_options = default_execution_options + delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE + existing_condition = statement.whereclause + # 组合条件 + if existing_condition is not None: + # 使用and_组合现有条件和逻辑删除条件 + new_condition = and_(existing_condition, delete_condition) + else: + new_condition = delete_condition + # 应用新条件(创建新的Select对象) + statement = statement.where(new_condition) + return await super().scalar(statement, params, execution_options=execution_options, + bind_arguments=bind_arguments, **kw) + async def execute( self, statement: Executable, @@ -34,7 +59,7 @@ class EnhanceAsyncSession(AsyncSession): if execution_options is None: default_execution_options = sig.parameters['execution_options'].default execution_options = default_execution_options - + print("type(statement):{}", type(statement)) if isinstance(statement, Select): print("这是查询语句,过滤逻辑删除") delete_condition = column(Constant.LOGICAL_DELETE_FIELD) == IsDelete.NO_DELETE @@ -46,8 +71,8 @@ class EnhanceAsyncSession(AsyncSession): new_condition = and_(existing_condition, delete_condition) else: new_condition = delete_condition - # 应用新条件(创建新的Select对象) - statement = statement.where(new_condition) + # 应用新条件(创建新的Select对象) + statement = statement.where(new_condition) if isinstance(statement, Delete): # 检查是否跳过软删除(通过execution_options控制) skip_soft_delete = execution_options and execution_options.get("skip_soft_delete", False) diff --git a/entity/base_entity.py b/entity/base_entity.py index f774b49..b666511 100644 --- a/entity/base_entity.py +++ b/entity/base_entity.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, BigInteger +from sqlalchemy import Column, BigInteger, Select, Delete, Update from common.global_enums import IsDelete from utils import current_timestamp @@ -19,3 +19,36 @@ class DbBaseModel(SQLModel, table=False): # class Config: # arbitrary_types_allowed = True + @classmethod + def select(cls, fields=None): + if fields is None: + fields = cls + return Select(fields) + + @classmethod + def delete(cls): + return Delete(cls) + + @classmethod + def delete_by_id(cls, id: str): + return Delete(cls).where(cls.id==id) + + @classmethod + def delete_by_ids(cls, ids: list[str] | str): + if isinstance(ids, str): + ids = [ids] + return Delete(cls).where(cls.id.in_(ids)) + + @classmethod + def update(cls): + return Update(cls) + + @classmethod + def update_by_id(cls, id: str,update_dict: dict): + update_dict.pop("id",None) + return Update(cls).where(cls.id == id).values(**update_dict) + + @classmethod + def update_by_ids(cls, ids: list[str],update_dict: dict): + update_dict.pop("id",None) + return Update(cls).where(cls.id.in_(ids)).values(**update_dict) diff --git a/entity/dto/UserDto.py b/entity/dto/UserDto.py new file mode 100644 index 0000000..fb97cb0 --- /dev/null +++ b/entity/dto/UserDto.py @@ -0,0 +1,9 @@ +from typing import Optional + +from pydantic import Field + +from entity.dto.base import BasePageQueryReq + + +class UserQueryPageReq(BasePageQueryReq): + username: Optional[str]= Field(default=None,description=" asc或 desc") \ No newline at end of file diff --git a/entity/dto/__init__.py b/entity/dto/__init__.py index e69de29..a1655f6 100644 --- a/entity/dto/__init__.py +++ b/entity/dto/__init__.py @@ -0,0 +1,3 @@ +from beartype.claw import beartype_this_package + +beartype_this_package() diff --git a/entity/dto/base.py b/entity/dto/base.py new file mode 100644 index 0000000..32828a6 --- /dev/null +++ b/entity/dto/base.py @@ -0,0 +1,40 @@ +from typing import Optional, List, Any +from typing import Union + +from pydantic import BaseModel, Field + + +class BaseTabelDto(BaseModel): + id: Optional[str] = None + created_time: Optional[int] = None + # created_by: Optional[str] = None + updated_time: Optional[int] = None + # updated_by: Optional[str] = None + is_deleted: Optional[int] = None + +class BaseQueryReq(BaseTabelDto): + desc: Optional[str] = Field(default="desc", description=" asc或 desc") + orderby: Optional[str] = Field(default="created_time", description="根据什么字段排序") + + +class BasePageQueryReq(BaseQueryReq): + page_number: Optional[int] = Field(default=1, description="第几页") + page_size: Optional[int] = Field(default=12, description="一页多少条") + + +class BaseRenameReq(BaseModel): + id: str + name: str + + +class BasePageResp(BaseModel): + page_number: Optional[int] + page_size: Optional[int] + page_count: Optional[int] + desc: Optional[str] + orderby: Optional[str] + count: Optional[int] + data: Optional[List[Any]] + + class Config: + arbitrary_types_allowed = True diff --git a/middleware/db_session.py b/middleware/db_session.py index 19a1af2..50ac395 100644 --- a/middleware/db_session.py +++ b/middleware/db_session.py @@ -4,7 +4,7 @@ from core.global_context import current_session from entity import AsyncSessionLocal class DbSessionMiddleWare(BaseHTTPMiddleware): - async def db_session_middleware(self,request: Request, call_next): + async def dispatch(self,request: Request, call_next): async with AsyncSessionLocal() as session: # 设置会话到上下文变量 token = current_session.set(session) diff --git a/pyproject.toml b/pyproject.toml index 9d1dc7b..faf08c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "sqlmodel>=0.0.25", "aiomysql>=0.2.0", "beartype>=0.21.0", + "fastapi-pagination>=0.14.1", ] [[tool.uv.index]] url = "https://mirrors.aliyun.com/pypi/simple" diff --git a/router/user_app.py b/router/user_app.py new file mode 100644 index 0000000..602e6e0 --- /dev/null +++ b/router/user_app.py @@ -0,0 +1,10 @@ +from fastapi import APIRouter, Query + +from entity.dto.UserDto import UserQueryPageReq +from service.user_service import UserService + +router = APIRouter(prefix="/user", tags=["用户"]) +base_service = UserService +@router.get("/page") +async def page(req:UserQueryPageReq=Query(...)): + return await base_service.get_by_page(req) \ No newline at end of file diff --git a/service/base_service.py b/service/base_service.py index 72fd325..dbf9825 100644 --- a/service/base_service.py +++ b/service/base_service.py @@ -1,16 +1,19 @@ from typing import Union, Type, List, Any, TypeVar, Generic +from fastapi_pagination import Params, Page +from fastapi_pagination.ext.sqlalchemy import paginate from pydantic import BaseModel +from sqlalchemy import Select, select, func from sqlalchemy.ext.asyncio import AsyncSession from core.global_context import current_session +from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq from utils import get_uuid, current_timestamp -T = TypeVar('T') -class BaseService(Generic[T]): - model: Type[T] # 子类必须指定模型 +class BaseService: + model=None # 子类必须指定模型 @classmethod def get_db(cls) -> AsyncSession: @@ -21,15 +24,6 @@ class BaseService(Generic[T]): "Make sure to use this service within a request context.") return session - @classmethod - async def create(cls, **kwargs) -> T: - """通用创建方法""" - obj = cls.model(**kwargs) - db = cls.get_db() - db.add(obj) - await db.flush() - return obj - @classmethod def entity_conversion_dto(cls, entity_data: Union[list, model], dto: Type[BaseModel]) -> Union[ BaseModel, List[BaseModel]]: @@ -43,100 +37,21 @@ class BaseService(Generic[T]): dto_list.append(dto(**temp)) return dto_list - @classmethod - def query(cls, cols=None, reverse=None, order_by=None, **kwargs): - """Execute a database query with optional column selection and ordering. - - This method provides a flexible way to query the database with various filters - and sorting options. It supports column selection, sort order control, and - additional filter conditions. - - Args: - cols (list, optional): List of column names to select. If None, selects all columns. - reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order. - order_by (str, optional): Column name to sort results by. - **kwargs: Additional filter conditions passed as keyword arguments. - - Returns: - peewee.ModelSelect: A query result containing matching records. - """ - return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) - - @classmethod - def get_all(cls, cols=None, reverse=None, order_by=None): - """Retrieve all records from the database with optional column selection and ordering. - - This method fetches all records from the model's table with support for - column selection and result ordering. If no order_by is specified and reverse - is True, it defaults to ordering by created_time. - - Args: - cols (list, optional): List of column names to select. If None, selects all columns. - reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order. - order_by (str, optional): Column name to sort results by. Defaults to 'created_time' if reverse is specified. - - Returns: - peewee.ModelSelect: A query containing all matching records. - """ - if cols: - query_records = cls.model.select(*cols) - else: - query_records = cls.model.select() - if reverse is not None: - if not order_by or not hasattr(cls, order_by): - order_by = "created_time" - if reverse is True: - query_records = query_records.order_by(cls.model.getter_by(order_by).desc()) - elif reverse is False: - query_records = query_records.order_by(cls.model.getter_by(order_by).asc()) - return query_records - - @classmethod - def get(cls, **kwargs): - """Get a single record matching the given criteria. - - This method retrieves a single record from the database that matches - the specified filter conditions. - - Args: - **kwargs: Filter conditions as keyword arguments. - - Returns: - Model instance: Single matching record. - - Raises: - peewee.DoesNotExist: If no matching record is found. - """ - return cls.model.get(**kwargs) - - @classmethod - def get_or_none(cls, **kwargs): - """Get a single record or None if not found. - - This method attempts to retrieve a single record matching the given criteria, - returning None if no match is found instead of raising an exception. - - Args: - **kwargs: Filter conditions as keyword arguments. - - Returns: - Model instance or None: Matching record if found, None otherwise. - """ - try: - return cls.model.get(**kwargs) - except peewee.DoesNotExist: - return None @classmethod - def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]): + async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]): if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} - sessions = cls.get_query_session(query_params) - return cls.auto_page(sessions, query_params) + query_entity = cls.get_query_entity(query_params) + return await cls.auto_page(query_entity, query_params) + # @classmethod + # def count_query(cls,query: Select) -> Select: + # # type: ignore + # return select(func.count("*")).select_from(count_subquery) @classmethod - def auto_page(cls, sessions, query_params: Union[dict, BasePageQueryReq] = None, + async def auto_page(cls, query_entity, query_params: Union[dict, BasePageQueryReq] = None, dto_model_class: Type[BaseModel] = None): if not query_params: query_params = {} @@ -146,7 +61,11 @@ class BaseService(Generic[T]): page_size = query_params.get("page_size", 12) desc = query_params.get("desc", "desc") orderby = query_params.get("orderby", "created_time") - data_count = sessions.count() + # data_count = await sessions.count() + session = cls.get_db() + + # data_count = session.scalar(cls.count_query(query_entity)) + data_count =None if data_count == 0: return BasePageResp(**{ "page_number": page_number, @@ -157,18 +76,25 @@ class BaseService(Generic[T]): "data": [], }) if desc == "desc": - sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) + + query_entity = query_entity.order_by(getattr(cls.model,orderby).desc()) else: - sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) - sessions = sessions.paginate(int(page_number), int(page_size)) - datas = list(sessions.dicts()) - result = datas + query_entity = query_entity.order_by(getattr(cls.model,orderby).asc()) + query_page_result=await paginate(session, + query_entity, + Params(page=page_number,size=page_size)) + # query_entity = query_entity.offset((page_number - 1) * page_size).limit(page_size) + # query_exec_result = await session.execute(query_entity) + # result = query_exec_result.scalars().all() + # return query_page_result + result = query_page_result.items if dto_model_class is not None: - result = [dto_model_class(**item) for item in datas] + result = [dto_model_class(**item) for item in result] return BasePageResp(**{ "page_number": page_number, "page_size": page_size, - "count": data_count, + "page_count": query_page_result.pages, + "count": query_page_result.total, "desc": desc, "orderby": orderby, "data": result, @@ -181,13 +107,13 @@ class BaseService(Generic[T]): query_params = {k: v for k, v in query_params.items() if v is not None} desc = query_params.get("desc", "desc") orderby = query_params.get("orderby", "created_time") - sessions = cls.get_query_session(query_params) + sessions = cls.get_query_entity(query_params) if desc == "desc": sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) else: sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) - return list(sessions.dicts()) + return sessions.scalars().all() @classmethod def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]: @@ -197,12 +123,12 @@ class BaseService(Generic[T]): desc = query_params.get("desc", "desc") orderby = query_params.get("orderby", "created_time") sessions = cls.model.select(cls.model.id) - sessions = cls.get_query_session(query_params, sessions) + sessions = cls.get_query_entity(query_params, sessions) if desc == "desc": sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) else: sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) - return [item["id"] for item in list(sessions.dicts())] + return [item["id"] for item in sessions.scalars().all()] @classmethod def save(cls, **kwargs): @@ -217,32 +143,10 @@ class BaseService(Generic[T]): Returns: Model instance: The created record object. """ - + # todo sample_obj = cls.model(**kwargs).save(force_insert=True) return sample_obj > 0 - @classmethod - def insert(cls, **kwargs): - """Insert a new record with automatic ID and timestamps. - - This method creates a new record with automatically generated ID and timestamp fields. - It handles the creation of created_time, create_date, updated_time, and update_date fields. - - Args: - **kwargs: Record field values as keyword arguments. - - Returns: - Model instance: The newly created record object. - """ - if "id" not in kwargs: - kwargs["id"] = get_uuid() - - kwargs["created_time"] = current_timestamp() - # kwargs["create_date"] = datetime_format(datetime.now()) - kwargs["updated_time"] = current_timestamp() - # kwargs["update_date"] = datetime_format(datetime.now()) - sample_obj = cls.model(**kwargs).save(force_insert=True) - return sample_obj > 0 @classmethod def insert_many(cls, data_list, batch_size=100): @@ -266,6 +170,18 @@ class BaseService(Generic[T]): for i in range(0, len(data_list), batch_size): cls.model.insert_many(data_list[i: i + batch_size]).execute() + @classmethod + def update_by_id(cls, pid, data): + # Update a single record by ID + # Args: + # pid: Record ID + # data: Updated field values + # Returns: + # Number of records updated + data["updated_time"] = current_timestamp() + num = cls.model.update(data).where(cls.model.id == pid).execute() + return num + @classmethod def update_many_by_id(cls, data_list): """Update multiple records by their IDs. @@ -284,18 +200,6 @@ class BaseService(Generic[T]): # data["update_date"] = datetime_format(datetime.now()) cls.model.update(data).where(cls.model.id == data["id"]).execute() - @classmethod - def updated_by_id(cls, pid, data): - # Update a single record by ID - # Args: - # pid: Record ID - # data: Updated field values - # Returns: - # Number of records updated - data["updated_time"] = current_timestamp() - # data["update_date"] = datetime_format(datetime.now()) - num = cls.model.update(data).where(cls.model.id == pid).execute() - return num > 0 @classmethod def get_by_id(cls, pid): @@ -326,101 +230,16 @@ class BaseService(Generic[T]): objs = cls.model.select() return objs.where(cls.model.id.in_(pids)) - @classmethod - def get_last_by_create_time(cls): - # Get multiple records by their IDs - # Args: - # pids: List of record IDs - # cols: List of columns to select - # Returns: - # Query of matching records - latest = cls.model.select().order_by(cls.model.created_time.desc()).first() - return latest @classmethod def delete_by_id(cls, pid): - # Delete a record by ID - # Args: - # pid: Record ID - # Returns: - # Number of records deleted - return cls.model.delete().where(cls.model.id == pid).execute() - + ... @classmethod def delete_by_ids(cls, pids): - # Delete multiple records by their IDs - # Args: - # pids: List of record IDs - # Returns: - # Number of records deleted - with DB.atomic(): - res = cls.model.delete().where(cls.model.id.in_(pids)).execute() - return res - - @classmethod - def filter_delete(cls, filters): - # Delete records matching given filters - # Args: - # filters: List of filter conditions - # Returns: - # Number of records deleted - with DB.atomic(): - num = cls.model.delete().where(*filters).execute() - return num - - @classmethod - def filter_update(cls, filters, update_data): - # Update records matching given filters - # Args: - # filters: List of filter conditions - # update_data: Updated field values - # Returns: - # Number of records updated - with DB.atomic(): - return cls.model.update(update_data).where(*filters).execute() - - @staticmethod - def cut_list(tar_list, n): - # Split a list into chunks of size n - # Args: - # tar_list: List to split - # n: Chunk size - # Returns: - # List of tuples containing chunks - length = len(tar_list) - arr = range(length) - result = [tuple(tar_list[x: (x + n)]) for x in arr[::n]] - return result + ... @classmethod - def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None): - # Get records matching IN clause filters with optional column selection - # Args: - # in_key: Field name for IN clause - # in_filters_list: List of values for IN clause - # filters: Additional filter conditions - # cols: List of columns to select - # Returns: - # List of matching records - in_filters_tuple_list = cls.cut_list(in_filters_list, 20) - if not filters: - filters = [] - res_list = [] - if cols: - for i in in_filters_tuple_list: - query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters) - if query_records: - res_list.extend([query_record for query_record in query_records]) - else: - for i in in_filters_tuple_list: - query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters) - if query_records: - res_list.extend([query_record for query_record in query_records]) - return res_list - - @classmethod - def get_query_session(cls, query_params, sessions=None): - + def get_query_entity(cls, query_params, sessions=None): if sessions is None: sessions = cls.model.select() for key, value in query_params.items(): @@ -433,29 +252,10 @@ class BaseService(Generic[T]): def get_data_count(cls, query_params: dict = None): if not query_params: raise Exception("参数为空") - sessions = cls.get_query_session(query_params) + sessions = cls.get_query_entity(query_params) return sessions.count() @classmethod def is_exist(cls, query_params: dict = None): return cls.get_data_count(query_params) > 0 - @classmethod - def update_by_id(cls, pid, data): - # Update a single record by ID - # Args: - # pid: Record ID - # data: Updated field values - # Returns: - # Number of records updated - data["updated_time"] = current_timestamp() - num = cls.model.update(data).where(cls.model.id == pid).execute() - return num - - @classmethod - def check_base_permission(cls, model_data): - if isinstance(model_data, dict): - if model_data.get("created_by") != get_current_user().id: - raise RuntimeError("无操作权限,该操作仅创建者有此权限") - if model_data.created_by != get_current_user().id: - raise RuntimeError("无操作权限,该操作仅创建者有此权限") diff --git a/service/user_service.py b/service/user_service.py index 087d500..bc1cb81 100644 --- a/service/user_service.py +++ b/service/user_service.py @@ -1,62 +1,7 @@ -from contextvars import ContextVar -from contextvars import ContextVar -from typing import Optional, TypeVar, Generic, Type - -from fastapi import Request -from sqlalchemy.ext.asyncio import AsyncSession - -from entity import AsyncSessionLocal from entity.user import User +from service.base_service import BaseService -# 1. 创建上下文变量存储当前会话 -current_session: ContextVar[Optional[AsyncSession]] = ContextVar("current_session", default=None) - -# 3. 中间件:管理请求生命周期和会话 -async def db_session_middleware(request: Request, call_next): - async with AsyncSessionLocal() as session: - # 设置会话到上下文变量 - token = current_session.set(session) - try: - response = await call_next(request) - await session.commit() - except Exception: - await session.rollback() - raise - finally: - # 重置上下文变量 - current_session.reset(token) - return response - -# 4. 服务基类 -T = TypeVar('T') - -class BaseService(Generic[T]): - model: Type[T] = None # 子类必须指定模型 - - @classmethod - def get_db(cls) -> AsyncSession: - """获取当前请求的会话""" - session = current_session.get() - if session is None: - raise RuntimeError("No database session in context. " - "Make sure to use this service within a request context.") - return session - - @classmethod - async def create(cls, **kwargs) -> T: - """通用创建方法""" - obj = cls.model(**kwargs) - db = cls.get_db() - db.add(obj) - await db.flush() - return obj - - @classmethod - async def get(cls, id: int) -> Optional[T]: - """通用获取方法""" - db = cls.get_db() - return await db.get(cls.model, id) # 5. 具体服务类 -class UserService(BaseService[User]): +class UserService(BaseService): model = User # 指定模型 diff --git a/uv.lock b/uv.lock index 033a643..7bc9a7c 100644 --- a/uv.lock +++ b/uv.lock @@ -234,6 +234,20 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/e5/a6/5aa862489a2918a096166fd98d9fe86b7fd53c607678b3fa9d8c432d88d5/fastapi_cloud_cli-0.1.5-py3-none-any.whl", hash = "sha256:d80525fb9c0e8af122370891f9fa83cf5d496e4ad47a8dd26c0496a6c85a012a" }, ] +[[package]] +name = "fastapi-pagination" +version = "0.14.1" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "fastapi" }, + { name = "pydantic" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/60/38/e27d94adb7050d88cddccddb56ab51244a2a9a30377eb5487de8f2b22959/fastapi_pagination-0.14.1.tar.gz", hash = "sha256:d045f8c678cef69ac006236f32a6d808ae6ca549360f79e92eefe2528f3e6337" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/83/6e/7b6e63c3c004c4805fe8795f8355b30cc72f23aa2162f4072e6035ef1e84/fastapi_pagination-0.14.1-py3-none-any.whl", hash = "sha256:e5b698cab368b525f3b2ea2605c77dc22b00f50f67ac7d382dce8d3243fe60dc" }, +] + [[package]] name = "filelock" version = "3.15.4" @@ -1004,6 +1018,7 @@ dependencies = [ { name = "beartype" }, { name = "cachetools" }, { name = "fastapi", extra = ["standard"] }, + { name = "fastapi-pagination" }, { name = "filelock" }, { name = "httpx-sse" }, { name = "itsdangerous" }, @@ -1023,6 +1038,7 @@ requires-dist = [ { name = "beartype", specifier = ">=0.21.0" }, { name = "cachetools", specifier = "==5.3.3" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.116.1" }, + { name = "fastapi-pagination", specifier = ">=0.14.1" }, { name = "filelock", specifier = "==3.15.4" }, { name = "httpx-sse", specifier = ">=0.4.1" }, { name = "itsdangerous", specifier = "==2.1.2" },