You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
# backend.py
|
|
from contextlib import asynccontextmanager
|
|
from fastapi import FastAPI, HTTPException, Request, APIRouter
|
|
from pydantic import BaseModel
|
|
import asyncio
|
|
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
from rag import rag_generate_rule
|
|
|
|
@asynccontextmanager
|
|
async def lifespan(app : FastAPI):
|
|
# logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}')
|
|
print(f'开始启动')
|
|
print(f'启动成功')
|
|
yield
|
|
print(f'关闭成功')
|
|
|
|
app = FastAPI(title="RAG SQL Generator API", description="调用RAG生成SQL语句的后端服务", lifespan=lifespan)
|
|
|
|
rule_router = APIRouter(prefix="/rule")
|
|
|
|
# 添加CORS中间件 - 允许所有来源(开发环境)
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"], # 在生产环境中应该指定具体的域名
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
class QueryRequest(BaseModel):
|
|
query: str
|
|
|
|
class QueryResponse(BaseModel):
|
|
sql: str
|
|
|
|
@rule_router.post("/generate_sql", response_model=QueryResponse)
|
|
async def generate_sql(request: QueryRequest):
|
|
"""
|
|
接收用户问题并调用RAG生成SQL语句
|
|
|
|
- **query**: 用户的自然语言问题
|
|
- 返回生成的SQL语句
|
|
"""
|
|
try:
|
|
# 调用rag_generate_rule方法生成SQL
|
|
result = await rag_generate_rule(request.query)
|
|
return QueryResponse(sql=result)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}")
|
|
|
|
@rule_router.get("/health")
|
|
async def health_check():
|
|
"""
|
|
健康检查端点
|
|
"""
|
|
return {"status": "healthy"}
|
|
|
|
|
|
# 注册带前缀的路由器
|
|
app.include_router(rule_router)
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|