perf: 优化BaseService的类型提示及配置文件的设置

main
chenzhirong 4 months ago
parent 063ef1f443
commit 175c413b73

@ -1,3 +1,5 @@
SERVICE_CONF=application.yaml
CONF_YAML_NAME=application.yaml
LOAD_YAML=true
DEBUG=true
DATABASE_URL=mysql+aiomysql://username:password@host:port/database_name

@ -58,6 +58,7 @@ uv add xxx
```bash
# 复制环境变量模板并修改
cp .env.template .env
# 如需启用YAML 配置文件支持,需要将.env中 的LOAD_YAML设置为 true并执行以下操作
cp application-template.yaml application.yaml
# 根据需要编辑.env和application.yaml文件
```
@ -72,7 +73,6 @@ python main.py
## API路由
- `/monitor` - 系统监控相关API
- `/test` - 测试相关API
- `/user` - 用户相关API
## 开发指南

@ -1,6 +1,15 @@
from enum import IntEnum
from enum import IntEnum, StrEnum
class IsDelete(IntEnum):
NO_DELETE = 0
DELETE = 1
class LLMType(StrEnum):
CHAT = 'chat'
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'
TTS = 'tts'

@ -1,4 +1,3 @@
import asyncio
import sys
from contextlib import asynccontextmanager
from importlib.util import module_from_spec, spec_from_file_location
@ -13,7 +12,8 @@ from config import settings
__all__ = ["app"]
from entity import init_db, close_engine
from entity import close_engine
from entity.db_models import init_db
from exceptions.global_exc import configure_exception
@ -32,9 +32,10 @@ async def lifespan(app: FastAPI):
# 2. 把40个线程改成80
limiter.total_tokens = 80
await init_db()
yield # 上面是启动时做的操作,下面是关闭时做的操作
yield # 上面是启动时做的操作,下面是关闭时做的操作
await close_engine()
# FastAPI应用初始化
app = FastAPI(
title="Fast API",
@ -113,4 +114,3 @@ client_urls_prefix = [
for pages_dir in pages_dirs
for path in search_pages_path(pages_dir)
]

@ -14,6 +14,7 @@ class BaseSettings(Base):
env_file_encoding = 'utf-8'
extra = 'allow'
class Settings(BaseSettings):
"""应用配置
server目录为后端项目根目录, 在该目录下创建 "config.env" 文件, 写入环境变量(默认大写)会自动加载, 并覆盖同名配置(小写)
@ -25,7 +26,8 @@ class Settings(BaseSettings):
# 模式
mode: str = 'dev' # dev, prod
debug: bool = False # dev, prod
service_conf: str = 'application.yaml' # dev, prod
load_yaml: bool = True # 是否开启加载 yaml 配置文件
conf_yaml_name: str = 'application.yaml' # dev, prod
# 版本
api_version: str = '/v1'
# 时区
@ -35,23 +37,26 @@ class Settings(BaseSettings):
# Redis键前缀
redis_prefix: str = 'agent:'
# 当前域名
host_ip:str = '0.0.0.0'
host_ip: str = '0.0.0.0'
host_port: int = 8080
#sql驱动连接
# sql驱动连接
database_url: str = ''
# yaml配置
yaml_config: dict = {}
@lru_cache()
def get_settings() -> Settings:
"""获取并缓存应用配置"""
# 读取server目录下的配置
load_dotenv()
settings = Settings()
yaml_config = file_utils.load_yaml_conf(settings.service_conf)
# 将YAML配置存储到Settings实例中
settings.yaml_config = yaml_config
for k, v in settings.yaml_config.items():
if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default:
setattr(settings, k, v)
if settings.load_yaml:
yaml_config = file_utils.load_yaml_conf(settings.conf_yaml_name)
# 将YAML配置存储到Settings实例中
settings.yaml_config = yaml_config
for k, v in settings.yaml_config.items():
if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default:
setattr(settings, k, v)
return settings

@ -1,10 +1,8 @@
import asyncio
import inspect
from typing import Any
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlmodel import SQLModel
from common.constant import Constant
from common.global_enums import IsDelete
@ -143,24 +141,3 @@ AsyncSessionLocal = async_sessionmaker(
# 关闭引擎
async def close_engine():
await engine.dispose()
# 初始化数据库表(异步执行)
async def init_db():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
if __name__ == '__main__':
import user
async def main():
try:
await init_db()
finally:
await close_engine() # 确保引擎关闭
asyncio.run(main())

@ -0,0 +1,15 @@
from sqlmodel import SQLModel
from entity import DbBaseModel, engine
# 初始化数据库表(异步执行)
async def init_db():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
class User(DbBaseModel, table=True):
__tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写
username: str
password: str

@ -1,7 +0,0 @@
from entity.base_entity import DbBaseModel
class User(DbBaseModel,table=True):
__tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写
username: str
password: str

@ -1,13 +0,0 @@
from fastapi import APIRouter
from config import settings
from utils.file_utils import get_project_base_directory
router = APIRouter()
@router.get("/test")
async def test():
return {
"settings":settings,
"base_path": get_project_base_directory(),
}

@ -1,12 +1,14 @@
from typing import Union, Type, List, Any
from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional
from fastapi_pagination import Params
from fastapi_pagination.ext.sqlalchemy import paginate
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from core.global_context import current_session
from entity import DbBaseModel
from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq
from utils import get_uuid
@ -18,10 +20,10 @@ session.scalars: 只适合单模型查询(不适合指定列或连表查询)
session.scalar: 直接明确获取一条数据可以直接返回无需额外处理
"""
T = TypeVar('T', bound=DbBaseModel)
class BaseService:
model = None # 子类必须指定模型
class BaseService(Generic[T]):
model: Type[T] # 子类必须指定模型
@classmethod
def get_db(cls) -> AsyncSession:
@ -64,7 +66,7 @@ class BaseService:
pass
@classmethod
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]):
async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
@ -73,7 +75,7 @@ class BaseService:
@classmethod
async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None,
dto_model_class: Type[BaseModel] = None):
dto_model_class: Type[BaseModel] = None)->BasePageResp:
if not query_params:
query_params = {}
if not isinstance(query_params, dict):
@ -114,7 +116,7 @@ class BaseService:
})
@classmethod
async def get_list(cls, query_params: Union[dict, BaseQueryReq]):
async def get_list(cls, query_params: Union[dict, BaseQueryReq])->List[T]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
@ -128,10 +130,10 @@ class BaseService:
query_stmt = query_stmt.order_by(field.asc())
session = cls.get_db()
exec_result = await session.execute(query_stmt)
return exec_result.scalars().all()
return list(exec_result.scalars().all())
@classmethod
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]:
async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[str]:
if not isinstance(query_params, dict):
query_params = query_params.model_dump()
query_params = {k: v for k, v in query_params.items() if v is not None}
@ -148,7 +150,7 @@ class BaseService:
return [item["id"] for item in exec_result.scalars().all()]
@classmethod
async def save(cls, **kwargs):
async def save(cls, **kwargs)->T:
sample_obj = cls.model(**kwargs)
session = cls.get_db()
session.add(sample_obj)
@ -156,7 +158,7 @@ class BaseService:
return sample_obj
@classmethod
async def insert_many(cls, data_list, batch_size=100):
async def insert_many(cls, data_list, batch_size=100)->None:
async with cls.get_db() as session:
for d in data_list:
if not d.get("id", None):
@ -166,27 +168,27 @@ class BaseService:
session.add_all(data_list[i: i + batch_size])
@classmethod
async def update_by_id(cls, pid, data):
async def update_by_id(cls, pid, data)-> int:
update_stmt = cls.model.update().where(cls.model.id == pid).values(**data)
session = cls.get_db()
result = await session.execute(update_stmt)
return result.rowcount
return result.rowcount()
@classmethod
async def update_many_by_id(cls, data_list):
async def update_many_by_id(cls, data_list)->None:
async with cls.get_db() as session:
for data in data_list:
stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data)
await session.execute(stmt)
@classmethod
async def get_by_id(cls, pid):
async def get_by_id(cls, pid)->T:
stmt = cls.model.select(cls.model.id == pid)
session = cls.get_db()
return await session.scalar(stmt)
@classmethod
async def get_by_ids(cls, pids, cols=None):
async def get_by_ids(cls, pids, cols=None)->List[T]:
if cols:
objs = cls.model.select(*cols)
else:
@ -194,21 +196,21 @@ class BaseService:
stmt = objs.where(cls.model.id.in_(pids))
session = cls.get_db()
result = await session.scalars(stmt)
return result.all()
return list(result.all())
@classmethod
async def delete_by_id(cls, pid):
async def delete_by_id(cls, pid)-> int:
del_stmt = cls.model.delete().where(cls.model.id == pid)
session = cls.get_db()
exec_result = await session.execute(del_stmt)
return exec_result.rowcount
return exec_result.rowcount()
@classmethod
async def delete_by_ids(cls, pids):
async def delete_by_ids(cls, pids)-> int:
session = cls.get_db()
del_stmt = cls.model.delete().where(cls.model.id.in_(pids))
result = await session.execute(del_stmt)
return result.rowcount
return result.rowcount()
@classmethod
async def get_data_count(cls, query_params: dict = None) -> int:
@ -220,5 +222,5 @@ class BaseService:
return await session.scalar(stmt)
@classmethod
async def is_exist(cls, query_params: dict = None):
async def is_exist(cls, query_params: dict = None) -> bool:
return await cls.get_data_count(query_params) > 0

@ -1,7 +1,7 @@
from entity.user import User
from entity.db_models import User
from service.base_service import BaseService
# 5. 具体服务类
class UserService(BaseService):
class UserService(BaseService[User]):
model = User # 指定模型

@ -1,66 +0,0 @@
import asyncio
from sqlalchemy import delete
from sqlmodel import select
from entity import AsyncSessionLocal, close_engine
from entity.user import User
async def add():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
# session = get_db_session()
async with AsyncSessionLocal() as session:
for user in add_list:
session.add(user)
await session.commit()
await close_engine()
print("完成")
async def remove():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
async with AsyncSessionLocal() as session:
for i in add_list:
delete_sql = delete(User).where(User.id == i.id)
await session.execute(delete_sql)
await session.commit()
await close_engine()
print("======================删除完成=================")
async def remove_one():
add_list = []
add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码"))
add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码"))
add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码"))
async with AsyncSessionLocal() as session:
# for i in add_list:
# delete_sql = delete(User).where(User.id == i.id)
user=await session.execute(select(User))
user = user.scalars().first()
session.delete(user)
await session.commit()
await close_engine()
print("======================删除完成=================")
async def get_all():
async with AsyncSessionLocal() as session:
sql = select(User)
result = await session.execute(sql)
print(result.scalars().all())
await session.commit()
await close_engine()
print("====================get_all==================")
def test(aa="ddd"):
print("aa=", aa)
if __name__ == "__main__":
# asyncio.run(add())
asyncio.run(get_all())
# asyncio.run(remove_one())
asyncio.run(remove())
asyncio.run(get_all())
# test(None)
Loading…
Cancel
Save