diff --git a/.env.template b/.env.template index eb3b0b7..482cca5 100644 --- a/.env.template +++ b/.env.template @@ -1,3 +1,5 @@ -SERVICE_CONF=application.yaml +CONF_YAML_NAME=application.yaml +LOAD_YAML=true DEBUG=true + DATABASE_URL=mysql+aiomysql://username:password@host:port/database_name \ No newline at end of file diff --git a/README.md b/README.md index 23ec279..ee86805 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ uv add xxx ```bash # 复制环境变量模板并修改 cp .env.template .env +# 如需启用YAML 配置文件支持,需要将.env中 的LOAD_YAML设置为 true,并执行以下操作 cp application-template.yaml application.yaml # 根据需要编辑.env和application.yaml文件 ``` @@ -72,7 +73,6 @@ python main.py ## API路由 - `/monitor` - 系统监控相关API -- `/test` - 测试相关API - `/user` - 用户相关API ## 开发指南 diff --git a/common/global_enums.py b/common/global_enums.py index cbb0c69..6a5c351 100644 --- a/common/global_enums.py +++ b/common/global_enums.py @@ -1,6 +1,15 @@ -from enum import IntEnum +from enum import IntEnum, StrEnum + class IsDelete(IntEnum): NO_DELETE = 0 DELETE = 1 + +class LLMType(StrEnum): + CHAT = 'chat' + EMBEDDING = 'embedding' + SPEECH2TEXT = 'speech2text' + IMAGE2TEXT = 'image2text' + RERANK = 'rerank' + TTS = 'tts' \ No newline at end of file diff --git a/config/fastapi_config.py b/config/fastapi_config.py index 82b4397..649ddcd 100644 --- a/config/fastapi_config.py +++ b/config/fastapi_config.py @@ -1,4 +1,3 @@ -import asyncio import sys from contextlib import asynccontextmanager from importlib.util import module_from_spec, spec_from_file_location @@ -13,7 +12,8 @@ from config import settings __all__ = ["app"] -from entity import init_db, close_engine +from entity import close_engine +from entity.db_models import init_db from exceptions.global_exc import configure_exception @@ -32,9 +32,10 @@ async def lifespan(app: FastAPI): # 2. 把40个线程改成80 limiter.total_tokens = 80 await init_db() - yield # 上面是启动时做的操作,下面是关闭时做的操作 + yield # 上面是启动时做的操作,下面是关闭时做的操作 await close_engine() + # FastAPI应用初始化 app = FastAPI( title="Fast API", @@ -113,4 +114,3 @@ client_urls_prefix = [ for pages_dir in pages_dirs for path in search_pages_path(pages_dir) ] - diff --git a/config/settings.py b/config/settings.py index 2d9675d..7f8beb0 100644 --- a/config/settings.py +++ b/config/settings.py @@ -14,6 +14,7 @@ class BaseSettings(Base): env_file_encoding = 'utf-8' extra = 'allow' + class Settings(BaseSettings): """应用配置 server目录为后端项目根目录, 在该目录下创建 "config.env" 文件, 写入环境变量(默认大写)会自动加载, 并覆盖同名配置(小写) @@ -25,7 +26,8 @@ class Settings(BaseSettings): # 模式 mode: str = 'dev' # dev, prod debug: bool = False # dev, prod - service_conf: str = 'application.yaml' # dev, prod + load_yaml: bool = True # 是否开启加载 yaml 配置文件 + conf_yaml_name: str = 'application.yaml' # dev, prod # 版本 api_version: str = '/v1' # 时区 @@ -35,23 +37,26 @@ class Settings(BaseSettings): # Redis键前缀 redis_prefix: str = 'agent:' # 当前域名 - host_ip:str = '0.0.0.0' + host_ip: str = '0.0.0.0' host_port: int = 8080 - #sql驱动连接 + # sql驱动连接 database_url: str = '' + # yaml配置 yaml_config: dict = {} + @lru_cache() def get_settings() -> Settings: """获取并缓存应用配置""" # 读取server目录下的配置 load_dotenv() settings = Settings() - yaml_config = file_utils.load_yaml_conf(settings.service_conf) - # 将YAML配置存储到Settings实例中 - settings.yaml_config = yaml_config - for k, v in settings.yaml_config.items(): - if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default: - setattr(settings, k, v) + if settings.load_yaml: + yaml_config = file_utils.load_yaml_conf(settings.conf_yaml_name) + # 将YAML配置存储到Settings实例中 + settings.yaml_config = yaml_config + for k, v in settings.yaml_config.items(): + if not hasattr(settings, k) or getattr(settings, k) == settings.__fields__[k].default: + setattr(settings, k, v) return settings diff --git a/entity/__init__.py b/entity/__init__.py index ba39f26..adbd942 100644 --- a/entity/__init__.py +++ b/entity/__init__.py @@ -1,10 +1,8 @@ -import asyncio import inspect from typing import Any from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession -from sqlmodel import SQLModel from common.constant import Constant from common.global_enums import IsDelete @@ -143,24 +141,3 @@ AsyncSessionLocal = async_sessionmaker( # 关闭引擎 async def close_engine(): await engine.dispose() - - -# 初始化数据库表(异步执行) -async def init_db(): - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - - -if __name__ == '__main__': - import user - - - async def main(): - try: - await init_db() - finally: - await close_engine() # 确保引擎关闭 - - - asyncio.run(main()) diff --git a/entity/db_models.py b/entity/db_models.py new file mode 100644 index 0000000..920e426 --- /dev/null +++ b/entity/db_models.py @@ -0,0 +1,15 @@ +from sqlmodel import SQLModel + +from entity import DbBaseModel, engine + + +# 初始化数据库表(异步执行) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + +class User(DbBaseModel, table=True): + __tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写 + username: str + password: str diff --git a/entity/user.py b/entity/user.py deleted file mode 100644 index 57e6069..0000000 --- a/entity/user.py +++ /dev/null @@ -1,7 +0,0 @@ -from entity.base_entity import DbBaseModel - - -class User(DbBaseModel,table=True): - __tablename__ = "user" # 可以显式指定数据库表名,默认实体名转小写 - username: str - password: str diff --git a/router/monitor.py b/router/monitor_app.py similarity index 100% rename from router/monitor.py rename to router/monitor_app.py diff --git a/router/test.py b/router/test.py deleted file mode 100644 index 69ba9db..0000000 --- a/router/test.py +++ /dev/null @@ -1,13 +0,0 @@ -from fastapi import APIRouter - -from config import settings -from utils.file_utils import get_project_base_directory - -router = APIRouter() - -@router.get("/test") -async def test(): - return { - "settings":settings, - "base_path": get_project_base_directory(), - } \ No newline at end of file diff --git a/service/base_service.py b/service/base_service.py index 82b5416..66d3f45 100644 --- a/service/base_service.py +++ b/service/base_service.py @@ -1,12 +1,14 @@ -from typing import Union, Type, List, Any +from typing import Union, Type, List, Any, TypeVar, Generic, Callable, Coroutine, Optional from fastapi_pagination import Params from fastapi_pagination.ext.sqlalchemy import paginate from pydantic import BaseModel from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import SQLModel from core.global_context import current_session +from entity import DbBaseModel from entity.dto.base import BasePageQueryReq, BasePageResp, BaseQueryReq from utils import get_uuid @@ -18,10 +20,10 @@ session.scalars: 只适合单模型查询(不适合指定列或连表查询) session.scalar: 直接明确获取一条数据,可以直接返回,无需额外处理 """ +T = TypeVar('T', bound=DbBaseModel) - -class BaseService: - model = None # 子类必须指定模型 +class BaseService(Generic[T]): + model: Type[T] # 子类必须指定模型 @classmethod def get_db(cls) -> AsyncSession: @@ -64,7 +66,7 @@ class BaseService: pass @classmethod - async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq]): + async def get_by_page(cls, query_params: Union[dict, BasePageQueryReq])->BasePageResp: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} @@ -73,7 +75,7 @@ class BaseService: @classmethod async def auto_page(cls, query_stmt, query_params: Union[dict, BasePageQueryReq] = None, - dto_model_class: Type[BaseModel] = None): + dto_model_class: Type[BaseModel] = None)->BasePageResp: if not query_params: query_params = {} if not isinstance(query_params, dict): @@ -114,7 +116,7 @@ class BaseService: }) @classmethod - async def get_list(cls, query_params: Union[dict, BaseQueryReq]): + async def get_list(cls, query_params: Union[dict, BaseQueryReq])->List[T]: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} @@ -128,10 +130,10 @@ class BaseService: query_stmt = query_stmt.order_by(field.asc()) session = cls.get_db() exec_result = await session.execute(query_stmt) - return exec_result.scalars().all() + return list(exec_result.scalars().all()) @classmethod - async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[Any]: + async def get_id_list(cls, query_params: Union[dict, BaseQueryReq]) -> List[str]: if not isinstance(query_params, dict): query_params = query_params.model_dump() query_params = {k: v for k, v in query_params.items() if v is not None} @@ -148,7 +150,7 @@ class BaseService: return [item["id"] for item in exec_result.scalars().all()] @classmethod - async def save(cls, **kwargs): + async def save(cls, **kwargs)->T: sample_obj = cls.model(**kwargs) session = cls.get_db() session.add(sample_obj) @@ -156,7 +158,7 @@ class BaseService: return sample_obj @classmethod - async def insert_many(cls, data_list, batch_size=100): + async def insert_many(cls, data_list, batch_size=100)->None: async with cls.get_db() as session: for d in data_list: if not d.get("id", None): @@ -166,27 +168,27 @@ class BaseService: session.add_all(data_list[i: i + batch_size]) @classmethod - async def update_by_id(cls, pid, data): + async def update_by_id(cls, pid, data)-> int: update_stmt = cls.model.update().where(cls.model.id == pid).values(**data) session = cls.get_db() result = await session.execute(update_stmt) - return result.rowcount + return result.rowcount() @classmethod - async def update_many_by_id(cls, data_list): + async def update_many_by_id(cls, data_list)->None: async with cls.get_db() as session: for data in data_list: stmt = cls.model.update().where(cls.model.id == data["id"]).values(**data) await session.execute(stmt) @classmethod - async def get_by_id(cls, pid): + async def get_by_id(cls, pid)->T: stmt = cls.model.select(cls.model.id == pid) session = cls.get_db() return await session.scalar(stmt) @classmethod - async def get_by_ids(cls, pids, cols=None): + async def get_by_ids(cls, pids, cols=None)->List[T]: if cols: objs = cls.model.select(*cols) else: @@ -194,21 +196,21 @@ class BaseService: stmt = objs.where(cls.model.id.in_(pids)) session = cls.get_db() result = await session.scalars(stmt) - return result.all() + return list(result.all()) @classmethod - async def delete_by_id(cls, pid): + async def delete_by_id(cls, pid)-> int: del_stmt = cls.model.delete().where(cls.model.id == pid) session = cls.get_db() exec_result = await session.execute(del_stmt) - return exec_result.rowcount + return exec_result.rowcount() @classmethod - async def delete_by_ids(cls, pids): + async def delete_by_ids(cls, pids)-> int: session = cls.get_db() del_stmt = cls.model.delete().where(cls.model.id.in_(pids)) result = await session.execute(del_stmt) - return result.rowcount + return result.rowcount() @classmethod async def get_data_count(cls, query_params: dict = None) -> int: @@ -220,5 +222,5 @@ class BaseService: return await session.scalar(stmt) @classmethod - async def is_exist(cls, query_params: dict = None): + async def is_exist(cls, query_params: dict = None) -> bool: return await cls.get_data_count(query_params) > 0 diff --git a/service/user_service.py b/service/user_service.py index bc1cb81..a3169cc 100644 --- a/service/user_service.py +++ b/service/user_service.py @@ -1,7 +1,7 @@ -from entity.user import User +from entity.db_models import User from service.base_service import BaseService # 5. 具体服务类 -class UserService(BaseService): +class UserService(BaseService[User]): model = User # 指定模型 diff --git a/test.py b/test.py deleted file mode 100644 index e9c81f9..0000000 --- a/test.py +++ /dev/null @@ -1,66 +0,0 @@ -import asyncio - -from sqlalchemy import delete -from sqlmodel import select - -from entity import AsyncSessionLocal, close_engine -from entity.user import User - - -async def add(): - add_list = [] - add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码")) - add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码")) - add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码")) - # session = get_db_session() - async with AsyncSessionLocal() as session: - for user in add_list: - session.add(user) - await session.commit() - await close_engine() - print("完成") - - -async def remove(): - add_list = [] - add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码")) - add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码")) - add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码")) - async with AsyncSessionLocal() as session: - for i in add_list: - delete_sql = delete(User).where(User.id == i.id) - await session.execute(delete_sql) - await session.commit() - await close_engine() - print("======================删除完成=================") -async def remove_one(): - add_list = [] - add_list.append(User(id="a460359e960311f09677c922f415afd9", username="u1", password="密码")) - add_list.append(User(id="a460416a960311f09677c922f415afd9", username="u2", password="密码")) - add_list.append(User(id="a46042be960311f09677c922f415afd9", username="u3", password="密码")) - async with AsyncSessionLocal() as session: - # for i in add_list: - # delete_sql = delete(User).where(User.id == i.id) - user=await session.execute(select(User)) - user = user.scalars().first() - session.delete(user) - await session.commit() - await close_engine() - print("======================删除完成=================") -async def get_all(): - async with AsyncSessionLocal() as session: - sql = select(User) - result = await session.execute(sql) - print(result.scalars().all()) - await session.commit() - await close_engine() - print("====================get_all==================") -def test(aa="ddd"): - print("aa=", aa) -if __name__ == "__main__": - # asyncio.run(add()) - asyncio.run(get_all()) - # asyncio.run(remove_one()) - asyncio.run(remove()) - asyncio.run(get_all()) - # test(None)