perf: 优化BaseService的类型提示及配置文件的设置

main
chenzhirong 4 months ago
parent 063ef1f443
commit 175c413b73

@ -1,3 +1,5 @@
SERVICE_CONF=application.yaml CONF_YAML_NAME=application.yaml
LOAD_YAML=true
DEBUG=true DEBUG=true
DATABASE_URL=mysql+aiomysql://username:password@host:port/database_name DATABASE_URL=mysql+aiomysql://username:password@host:port/database_name

@ -58,6 +58,7 @@ uv add xxx
```bash ```bash
# 复制环境变量模板并修改 # 复制环境变量模板并修改
cp .env.template .env cp .env.template .env
# 如需启用YAML 配置文件支持,需要将.env中 的LOAD_YAML设置为 true并执行以下操作
cp application-template.yaml application.yaml cp application-template.yaml application.yaml
# 根据需要编辑.env和application.yaml文件 # 根据需要编辑.env和application.yaml文件
``` ```
@ -72,7 +73,6 @@ python main.py
## API路由 ## API路由
- `/monitor` - 系统监控相关API - `/monitor` - 系统监控相关API
- `/test` - 测试相关API
- `/user` - 用户相关API - `/user` - 用户相关API
## 开发指南 ## 开发指南

@ -1,6 +1,15 @@
from enum import IntEnum from enum import IntEnum, StrEnum
class IsDelete(IntEnum): class IsDelete(IntEnum):
NO_DELETE = 0 NO_DELETE = 0
DELETE = 1 DELETE = 1
class LLMType(StrEnum):
CHAT = 'chat'
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'
TTS = 'tts'

@ -1,4 +1,3 @@
import asyncio
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from importlib.util import module_from_spec, spec_from_file_location from importlib.util import module_from_spec, spec_from_file_location
@ -13,7 +12,8 @@ from config import settings
__all__ = ["app"] __all__ = ["app"]
from entity import init_db, close_engine from entity import close_engine
from entity.db_models import init_db
from exceptions.global_exc import configure_exception from exceptions.global_exc import configure_exception
@ -32,9 +32,10 @@ async def lifespan(app: FastAPI):
# 2. 把40个线程改成80 # 2. 把40个线程改成80
limiter.total_tokens = 80 limiter.total_tokens = 80
await init_db() await init_db()
yield # 上面是启动时做的操作,下面是关闭时做的操作 yield # 上面是启动时做的操作,下面是关闭时做的操作
await close_engine() await close_engine()
# FastAPI应用初始化 # FastAPI应用初始化
app = FastAPI( app = FastAPI(
title="Fast API", title="Fast API",
@ -113,4 +114,3 @@ client_urls_prefix = [
for pages_dir in pages_dirs for pages_dir in pages_dirs
for path in search_pages_path(pages_dir) for path in search_pages_path(pages_dir)
] ]

@ -14,6 +14,7 @@ class BaseSettings(Base):
env_file_encoding = 'utf-8' env_file_encoding = 'utf-8'
extra = 'allow' extra = 'allow'
class Settings(BaseSettings): class Settings(BaseSettings):
"""应用配置 """应用配置
server目录为后端项目根目录, 在该目录下创建 "config.env" 文件, 写入环境变量(默认大写)会自动加载, 并覆盖同名配置(小写) server目录为后端项目根目录, 在该目录下创建 "config.env" 文件, 写入环境变量(默认大写)会自动加载, 并覆盖同名配置(小写)
@ -25,7 +26,8 @@ class Settings(BaseSettings):
# 模式 # 模式
mode: str = 'dev' # dev, prod mode: str = 'dev' # dev, prod
debug: bool = False # dev, prod debug: bool = False # dev, prod
service_conf: str = 'application.yaml' # dev, prod load_yaml: bool = True # 是否开启加载 yaml 配置文件
conf_yaml_name: str = 'application.yaml' # dev, prod
# 版本 # 版本
api_version: str = '/v1' api_version: str = '/v1'
# 时区 # 时区
@ -35,23 +37,26 @@ class Settings(BaseSettings):
# Redis键前缀 # Redis键前缀
redis_prefix: str = 'agent:' redis_prefix: str = 'agent:'
# 当前域名 # 当前域名
host_ip:str = '0.0.0.0' host_ip: str = '0.0.0.0'
host_port: int = 8080 host_port: int = 8080
#sql驱动连接 # sql驱动连接
database_url: str = '' database_url: str = ''
# yaml配置 # yaml配置
yaml_config: dict = {} yaml_config: dict = {}
@lru_cache() @lru_cache()
def get_settings() -> Settings: def get_settings() -> Settings:
"""获取并缓存应用配置""" """获取并缓存应用配置"""
# 读取server目录下的配置 # 读取server目录下的配置
load_dotenv() load_dotenv()
settings = Settings() settings = Settings()
yaml_config = file_utils.load_yaml_conf(settings.service_conf) if settings.load_yaml:
# 将YAML配置存储到Settings实例中 yaml_config = file_utils.load_yaml_conf(settings.conf_yaml_name)
settings.yaml_config = yaml_config # 将YAML配置存储到Settings实例中
for k, v in settings.yaml_config.items(): settings.yaml_config = yaml_config
if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default: for k, v in settings.yaml_config.items():
setattr(settings, k, v) if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default:
setattr(settings, k, v)
return settings return settings

@ -1,10 +1,8 @@
import asyncio
import inspect import inspect
from typing import Any from typing import Any
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_ from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlmodel import SQLModel
from common.constant import Constant from common.constant import Constant
from common.global_enums import IsDelete from common.global_enums import IsDelete
@ -143,24 +141,3 @@ AsyncSessionLocal = async_sessionmaker(
# 关闭引擎 # 关闭引擎
async def close_engine(): async def close_engine():
await engine.dispose() await engine.dispose()
# 初始化数据库表(异步执行)
async def init_db():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
if __name__ == '__main__':
import user
async def main():
try:
await init_db()
finally:
await close_engine() # 确保引擎关闭
asyncio.run(main())

@ -0,0 +1,15 @@
from sqlmodel import SQLModel
from entity import DbBaseModel, engine
# 初始化数据库表(异步执行)
async def init_db():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
class User(DbBaseModel, table=True):
__tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写
username: str
password: str

@ -1,7 +0,0 @@
from entity.base_entity import DbBaseModel
class User(DbBaseModel,table=True):
__tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写
username: str
password: str

@ -1,13 +0,0 @@
from fastapi import APIRouter
from config import settings
from utils.file_utils import get_project_base_directory
router = APIRouter()
@router.get("/test")
async def test():
return {
"settings":settings,
"base_path": get_project_base_directory(),
}

@ -1,12 +1,14 @@
from typing import Union, Type, List, Any from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from core.global_context import current_session from core.global_context import current_session
from entity import DbBaseModel
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
from utils import get_uuid from utils import get_uuid
@ -18,10 +20,10 @@ session.scalars: 只适合单模型查询(不适合指定列或连表查询)
session.scalar: 直接明确获取一条数据可以直接返回无需额外处理 session.scalar: 直接明确获取一条数据可以直接返回无需额外处理
""" """
T = TypeVar('T', bound=DbBaseModel)
class BaseService(Generic[T]):
class BaseService: model: Type[T] # 子类必须指定模型
model = None # 子类必须指定模型
@classmethod @classmethod
def get_db(cls) -> AsyncSession: def get_db(cls) -> AsyncSession:
@ -64,7 +66,7 @@ class BaseService:
pass pass
@classmethod @classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]): async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
@ -73,7 +75,7 @@ class BaseService:
@classmethod @classmethod
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None): dto_model_class: Type[BaseModel] = None)->BasePageResp:
if not query_params: if not query_params:
query_params = {} query_params = {}
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
@ -114,7 +116,7 @@ class BaseService:
}) })
@classmethod @classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq]): async def get_list(cls, query_params: Union[dict, BaseQueryReq])->List[T]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
@ -128,10 +130,10 @@ class BaseService:
query_stmt = query_stmt.order_by(field.asc()) query_stmt = query_stmt.order_by(field.asc())
session = cls.get_db() session = cls.get_db()
exec_result = await session.execute(query_stmt) exec_result = await session.execute(query_stmt)
return exec_result.scalars().all() return list(exec_result.scalars().all())
@classmethod @classmethod
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]: async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[str]:
if not isinstance(query_params, dict): if not isinstance(query_params, dict):
query_params = query_params.model_dump() query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None} query_params = {k: v for k, v in query_params.items() if v is not None}
@ -148,7 +150,7 @@ class BaseService:
return [item["id"] for item in exec_result.scalars().all()] return [item["id"] for item in exec_result.scalars().all()]
@classmethod @classmethod
async def save(cls, **kwargs): async def save(cls, **kwargs)->T:
sample_obj = cls.model(**kwargs) sample_obj = cls.model(**kwargs)
session = cls.get_db() session = cls.get_db()
session.add(sample_obj) session.add(sample_obj)
@ -156,7 +158,7 @@ class BaseService:
return sample_obj return sample_obj
@classmethod @classmethod
async def insert_many(cls, data_list, batch_size=100): async def insert_many(cls, data_list, batch_size=100)->None:
async with cls.get_db() as session: async with cls.get_db() as session:
for d in data_list: for d in data_list:
if not d.get("id", None): if not d.get("id", None):
@ -166,27 +168,27 @@ class BaseService:
session.add_all(data_list[i: i + batch_size]) session.add_all(data_list[i: i + batch_size])
@classmethod @classmethod
async def update_by_id(cls, pid, data): async def update_by_id(cls, pid, data)-> int:
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
session = cls.get_db() session = cls.get_db()
result = await session.execute(update_stmt) result = await session.execute(update_stmt)
return result.rowcount return result.rowcount()
@classmethod @classmethod
async def update_many_by_id(cls, data_list): async def update_many_by_id(cls, data_list)->None:
async with cls.get_db() as session: async with cls.get_db() as session:
for data in data_list: for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt) await session.execute(stmt)
@classmethod @classmethod
async def get_by_id(cls, pid): async def get_by_id(cls, pid)->T:
stmt = cls.model.select(cls.model.id == pid) stmt = cls.model.select(cls.model.id == pid)
session = cls.get_db() session = cls.get_db()
return await session.scalar(stmt) return await session.scalar(stmt)
@classmethod @classmethod
async def get_by_ids(cls, pids, cols=None): async def get_by_ids(cls, pids, cols=None)->List[T]:
if cols: if cols:
objs = cls.model.select(*cols) objs = cls.model.select(*cols)
else: else:
@ -194,21 +196,21 @@ class BaseService:
stmt = objs.where(cls.model.id.in_(pids)) stmt = objs.where(cls.model.id.in_(pids))
session = cls.get_db() session = cls.get_db()
result = await session.scalars(stmt) result = await session.scalars(stmt)
return result.all() return list(result.all())
@classmethod @classmethod
async def delete_by_id(cls, pid): async def delete_by_id(cls, pid)-> int:
del_stmt = cls.model.delete().where(cls.model.id == pid) del_stmt = cls.model.delete().where(cls.model.id == pid)
session = cls.get_db() session = cls.get_db()
exec_result = await session.execute(del_stmt) exec_result = await session.execute(del_stmt)
return exec_result.rowcount return exec_result.rowcount()
@classmethod @classmethod
async def delete_by_ids(cls, pids): async def delete_by_ids(cls, pids)-> int:
session = cls.get_db() session = cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt) result = await session.execute(del_stmt)
return result.rowcount return result.rowcount()
@classmethod @classmethod
async def get_data_count(cls, query_params: dict = None) -> int: async def get_data_count(cls, query_params: dict = None) -> int:
@ -220,5 +222,5 @@ class BaseService:
return await session.scalar(stmt) return await session.scalar(stmt)
@classmethod @classmethod
async def is_exist(cls, query_params: dict = None): async def is_exist(cls, query_params: dict = None) -> bool:
return await cls.get_data_count(query_params) > 0 return await cls.get_data_count(query_params) > 0

@ -1,7 +1,7 @@
from entity.user import User from entity.db_models import User
from service.base_service import BaseService from service.base_service import BaseService
# 5. 具体服务类 # 5. 具体服务类
class UserService(BaseService): class UserService(BaseService[User]):
model = User # 指定模型 model = User # 指定模型

@ -1,66 +0,0 @@
import asyncio
from sqlalchemy import delete
from sqlmodel import select
from entity import AsyncSessionLocal, close_engine
from entity.user import User
async def add():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
# session = get_db_session()
async with AsyncSessionLocal() as session:
for user in add_list:
session.add(user)
await session.commit()
await close_engine()
print("完成")
async def remove():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
async with AsyncSessionLocal() as session:
for i in add_list:
delete_sql = delete(User).where(User.id == i.id)
await session.execute(delete_sql)
await session.commit()
await close_engine()
print("======================删除完成=================")
async def remove_one():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
async with AsyncSessionLocal() as session:
# for i in add_list:
# delete_sql = delete(User).where(User.id == i.id)
user=await session.execute(select(User))
user = user.scalars().first()
session.delete(user)
await session.commit()
await close_engine()
print("======================删除完成=================")
async def get_all():
async with AsyncSessionLocal() as session:
sql = select(User)
result = await session.execute(sql)
print(result.scalars().all())
await session.commit()
await close_engine()
print("====================get_all==================")
def test(aa="ddd"):
print("aa=", aa)
if __name__ == "__main__":
# asyncio.run(add())
asyncio.run(get_all())
# asyncio.run(remove_one())
asyncio.run(remove())
asyncio.run(get_all())
# test(None)
Loading…
Cancel
Save