# backend.py from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, Request, APIRouter from pydantic import BaseModel import asyncio from starlette.middleware.cors import CORSMiddleware from middlewares.handle import handle_middleware from rag import rag_generate_rule from util.log_util import logger @asynccontextmanager async def lifespan(app : FastAPI): # logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}') print(f'开始启动') yield print(f'关闭成功') app = FastAPI(title="RAG SQL Generator API", description="调用RAG生成SQL语句的后端服务", lifespan=lifespan) rule_router = APIRouter(prefix="/rule") # 添加CORS中间件 - 允许所有来源(开发环境) # app.add_middleware( # CORSMiddleware, # allow_origins=["*"], # 在生产环境中应该指定具体的域名 # allow_credentials=True, # allow_methods=["*"], # allow_headers=["*"], # ) # 加载中间件 handle_middleware(app) class QueryRequest(BaseModel): query: str class QueryResponse(BaseModel): sql: str @rule_router.post("/generate_sql", response_model=QueryResponse) async def generate_sql(request: QueryRequest): """ 接收用户问题并调用RAG生成SQL语句 - **query**: 用户的自然语言问题 - 返回生成的SQL语句 """ try: # 调用rag_generate_rule方法生成SQL result = await rag_generate_rule(request.query) return QueryResponse(sql=result) except Exception as e: # print(e) # print(e.__context__) logger.info(e) logger.info(e.__context__) raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}") @rule_router.get("/health") async def health_check(): """ 健康检查端点 """ return {"status": "healthy"} # 注册带前缀的路由器 app.include_router(rule_router) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)