diff --git a/entity/base_entity.py b/entity/base_entity.py index b666511..df6cdbd 100644 --- a/entity/base_entity.py +++ b/entity/base_entity.py @@ -25,28 +25,15 @@ class DbBaseModel(SQLModel, table=False): fields = cls return Select(fields) + @classmethod def delete(cls): return Delete(cls) - @classmethod - def delete_by_id(cls, id: str): - return Delete(cls).where(cls.id==id) - - @classmethod - def delete_by_ids(cls, ids: list[str] | str): - if isinstance(ids, str): - ids = [ids] - return Delete(cls).where(cls.id.in_(ids)) - @classmethod def update(cls): return Update(cls) - @classmethod - def update_by_id(cls, id: str,update_dict: dict): - update_dict.pop("id",None) - return Update(cls).where(cls.id == id).values(**update_dict) @classmethod def update_by_ids(cls, ids: list[str],update_dict: dict): diff --git a/service/base_service.py b/service/base_service.py index dbf9825..a764299 100644 --- a/service/base_service.py +++ b/service/base_service.py @@ -1,19 +1,25 @@ -from typing import Union, Type, List, Any, TypeVar, Generic +from typing import Union, Type, List, Any -from fastapi_pagination import Params, Page +from fastapi_pagination import Params from fastapi_pagination.ext.sqlalchemy import paginate from pydantic import BaseModel -from sqlalchemy import Select, select, func +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from core.global_context import current_session from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq from utils import get_uuid, current_timestamp +""" +session.execute: 执行任意数据库操作语句,返回结果需要额外处理获取 数据形式:Row对象(类似元组) +-- exec_result.scalars().all(): 从session.execute的执行结果中获取全部的数据 +session.scalars: 只适合单模型查询(不适合指定列或连表查询),返回结果需要处理 数据形式:直接返回模型实例(如User对象) +-- 处理exec_result.all(): 从session.scalars的执行结果中获取全部的数据 +session.scalar: 直接明确获取一条数据,可以直接返回,无需额外处理 - +""" class BaseService: - model=None # 子类必须指定模型 + model = None # 子类必须指定模型 @classmethod def get_db(cls) -> AsyncSession: @@ -23,13 +29,24 @@ class BaseService: raise RuntimeError("No database session in context. " "Make sure to use this service within a request context.") return session - @classmethod - def entity_conversion_dto(cls, entity_data: Union[list, model], dto: Type[BaseModel]) -> Union[ + def get_query_stmt(cls, query_params, sessions=None,*,fields:list=None): + if sessions is None: + if fields: + sessions = cls.model.select(*fields) + else: + sessions = cls.model.select() + for key, value in query_params.items(): + if hasattr(cls.model, key) and value is not None: + field = getattr(cls.model, key) + sessions = sessions.where(field == value) + return sessions + @classmethod + def entity_conversion_dto(cls, entity_data: Union[list, BaseModel], dto: Type[BaseModel]) -> Union[ BaseModel, List[BaseModel]]: dto_list = [] if not isinstance(entity_data, list): - return dto(**entity_data.to_dict()) + return dto(**entity_data.model_dump()) for entity in entity_data: temp = entity if not isinstance(entity, dict): @@ -37,22 +54,17 @@ class BaseService: dto_list.append(dto(**temp)) return dto_list - @classmethod async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]): if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} - query_entity = cls.get_query_entity(query_params) - return await cls.auto_page(query_entity, query_params) - # @classmethod - # def count_query(cls,query: Select) -> Select: - # # type: ignore - # return select(func.count("*")).select_from(count_subquery) + query_stmt = cls.get_query_stmt(query_params) + return await cls.auto_page(query_stmt, query_params) @classmethod - async def auto_page(cls, query_entity, query_params: Union[dict, BasePageQueryReq] = None, - dto_model_class: Type[BaseModel] = None): + async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, + dto_model_class: Type[BaseModel] = None): if not query_params: query_params = {} if not isinstance(query_params, dict): @@ -61,11 +73,8 @@ class BaseService: page_size = query_params.get("page_size", 12) desc = query_params.get("desc", "desc") orderby = query_params.get("orderby", "created_time") - # data_count = await sessions.count() session = cls.get_db() - - # data_count = session.scalar(cls.count_query(query_entity)) - data_count =None + data_count = None if data_count == 0: return BasePageResp(**{ "page_number": page_number, @@ -76,17 +85,12 @@ class BaseService: "data": [], }) if desc == "desc": - - query_entity = query_entity.order_by(getattr(cls.model,orderby).desc()) + query_stmt = query_stmt.order_by(getattr(cls.model, orderby).desc()) else: - query_entity = query_entity.order_by(getattr(cls.model,orderby).asc()) - query_page_result=await paginate(session, - query_entity, - Params(page=page_number,size=page_size)) - # query_entity = query_entity.offset((page_number - 1) * page_size).limit(page_size) - # query_exec_result = await session.execute(query_entity) - # result = query_exec_result.scalars().all() - # return query_page_result + query_stmt = query_stmt.order_by(getattr(cls.model, orderby).asc()) + query_page_result = await paginate(session, + query_stmt, + Params(page=page_number, size=page_size)) result = query_page_result.items if dto_model_class is not None: result = [dto_model_class(**item) for item in result] @@ -101,161 +105,112 @@ class BaseService: }) @classmethod - def get_list(cls, query_params: Union[dict, BaseQueryReq]): + async def get_list(cls, query_params: Union[dict, BaseQueryReq]): if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} desc = query_params.get("desc", "desc") orderby = query_params.get("orderby", "created_time") - sessions = cls.get_query_entity(query_params) + query_stmt = cls.get_query_stmt(query_params) + field = getattr(cls.model, orderby) if desc == "desc": - sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) + query_stmt = query_stmt.order_by(field.desc()) else: - sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) - - return sessions.scalars().all() + query_stmt = query_stmt.order_by(field.asc()) + session = cls.get_db() + exec_result = await session.execute(query_stmt) + return exec_result.scalars().all() @classmethod - def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]: + async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} desc = query_params.get("desc", "desc") orderby = query_params.get("orderby", "created_time") - sessions = cls.model.select(cls.model.id) - sessions = cls.get_query_entity(query_params, sessions) + query_stmt = cls.model.select(cls.model.id) + query_stmt = cls.get_query_stmt(query_params, query_stmt) if desc == "desc": - sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) + query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).desc()) else: - sessions = sessions.order_by(cls.model.getter_by(orderby).asc()) - return [item["id"] for item in sessions.scalars().all()] + query_stmt = query_stmt.order_by(cls.model.getter_by(orderby).asc()) + session = cls.get_db() + exec_result = await session.execute(query_stmt) + return [item["id"] for item in exec_result.scalars().all()] @classmethod - def save(cls, **kwargs): - """Save a new record to database. - - This method creates a new record in the database with the provided field values, - forcing an insert operation rather than an update. - - Args: - **kwargs: Record field values as keyword arguments. - - Returns: - Model instance: The created record object. - """ - # todo - sample_obj = cls.model(**kwargs).save(force_insert=True) - return sample_obj > 0 - + async def save(cls, **kwargs): + sample_obj = cls.model(**kwargs) + session = cls.get_db() + session.add(sample_obj) + await session.flush() + return sample_obj @classmethod - def insert_many(cls, data_list, batch_size=100): - """Insert multiple records in batches. - - This method efficiently inserts multiple records into the database using batch processing. - It automatically sets creation timestamps for all records. - - Args: - data_list (list): List of dictionaries containing record data to insert. - batch_size (int, optional): Number of records to insert in each batch. Defaults to 100. - """ - - with DB.atomic(): + async def insert_many(cls, data_list, batch_size=100): + async with cls.get_db() as session: for d in data_list: - if not d.get("id", None): d["id"] = get_uuid() - d["created_time"] = current_timestamp() - # d["create_date"] = datetime_format(datetime.now()) + for i in range(0, len(data_list), batch_size): - cls.model.insert_many(data_list[i: i + batch_size]).execute() + session.add_all(data_list[i: i + batch_size]) @classmethod - def update_by_id(cls, pid, data): - # Update a single record by ID - # Args: - # pid: Record ID - # data: Updated field values - # Returns: - # Number of records updated - data["updated_time"] = current_timestamp() - num = cls.model.update(data).where(cls.model.id == pid).execute() - return num + async def update_by_id(cls, pid, data): + update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) + session = cls.get_db() + result = await session.execute(update_stmt) + return result.rowcount @classmethod - def update_many_by_id(cls, data_list): - """Update multiple records by their IDs. - - This method updates multiple records in the database, identified by their IDs. - It automatically updates the updated_time and update_date fields for each record. - - Args: - data_list (list): List of dictionaries containing record data to update. - Each dictionary must include an 'id' field. - """ - - with DB.atomic(): + async def update_many_by_id(cls, data_list): + async with cls.get_db() as session: for data in data_list: - data["updated_time"] = current_timestamp() - # data["update_date"] = datetime_format(datetime.now()) - cls.model.update(data).where(cls.model.id == data["id"]).execute() - + stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) + await session.execute(stmt) @classmethod - def get_by_id(cls, pid): - # Get a record by ID - # Args: - # pid: Record ID - # Returns: - # Tuple of (success, record) - try: - obj = cls.model.get_or_none(cls.model.id == pid) - if obj: - return True, obj - except Exception: - pass - return False, None + async def get_by_id(cls, pid): + stmt = cls.model.select(cls.model.id == pid) + session = cls.get_db() + return await session.scalar(stmt) @classmethod - def get_by_ids(cls, pids, cols=None): - # Get multiple records by their IDs - # Args: - # pids: List of record IDs - # cols: List of columns to select - # Returns: - # Query of matching records + async def get_by_ids(cls, pids, cols=None): if cols: objs = cls.model.select(*cols) else: objs = cls.model.select() - return objs.where(cls.model.id.in_(pids)) - + stmt = objs.where(cls.model.id.in_(pids)) + session = cls.get_db() + result = await session.scalars(stmt) + return result.all() @classmethod - def delete_by_id(cls, pid): - ... - @classmethod - def delete_by_ids(cls, pids): - ... + async def delete_by_id(cls, pid): + del_stmt = cls.model.delete().where(cls.model.id == pid) + session = cls.get_db() + exec_result = await session.execute(del_stmt) + return exec_result.rowcount @classmethod - def get_query_entity(cls, query_params, sessions=None): - if sessions is None: - sessions = cls.model.select() - for key, value in query_params.items(): - if hasattr(cls.model, key): - field = getattr(cls.model, key) - sessions = sessions.where(field == value) - return sessions + async def delete_by_ids(cls, pids): + session = cls.get_db() + del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) + result = await session.execute(del_stmt) + return result.rowcount @classmethod - def get_data_count(cls, query_params: dict = None): + async def get_data_count(cls, query_params: dict = None)->int: if not query_params: raise Exception("参数为空") - sessions = cls.get_query_entity(query_params) - return sessions.count() + stmt = cls.get_query_stmt(query_params,fields=[func.count(cls.model.id)]) + # stmt = cls.get_query_stmt(query_params) + session = cls.get_db() + return await session.scalar(stmt) - @classmethod - def is_exist(cls, query_params: dict = None): - return cls.get_data_count(query_params) > 0 + @classmethod + async def is_exist(cls, query_params: dict = None): + return await cls.get_data_count(query_params) > 0