You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

67 lines
1.8 KiB
Python

# backend.py
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, APIRouter
from pydantic import BaseModel
import asyncio
from starlette.middleware.cors import CORSMiddleware
from rag import rag_generate_rule
@asynccontextmanager
async def lifespan(app : FastAPI):
# logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}')
print(f'开始启动')
print(f'启动成功')
yield
print(f'关闭成功')
app = FastAPI(title="RAG SQL Generator API", description="调用RAG生成SQL语句的后端服务", lifespan=lifespan)
rule_router = APIRouter(prefix="/rule")
# 添加CORS中间件 - 允许所有来源(开发环境)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 在生产环境中应该指定具体的域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class QueryRequest(BaseModel):
query: str
class QueryResponse(BaseModel):
sql: str
@rule_router.post("/generate_sql", response_model=QueryResponse)
async def generate_sql(request: QueryRequest):
"""
接收用户问题并调用RAG生成SQL语句
- **query**: 用户的自然语言问题
- 返回生成的SQL语句
"""
try:
# 调用rag_generate_rule方法生成SQL
result = await rag_generate_rule(request.query)
return QueryResponse(sql=result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}")
@rule_router.get("/health")
async def health_check():
"""
健康检查端点
"""
return {"status": "healthy"}
# 注册带前缀的路由器
app.include_router(rule_router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)