feat: 实现 RAG SQL生成系统的后端接口和前端页面

- 新增后端 API,使用 FastAPI 框架实现
- 创建前端聊天界面,使用 HTML 和 JavaScript 实现
- 添加环境变量配置文件,支持不同环境的配置
- 实现与 MySQL 和 pgvector 数据库的连接和查询
- 集成大语言模型和向量模型,用于生成 SQL 语句
main
lijiazheng 6 months ago
parent 1025da7b26
commit e4e938d73b

@ -0,0 +1,58 @@
# -------- 应用配置 --------
# 应用运行环境
APP_ENV = 'dev'
# 应用名称
APP_NAME = '规则生成系统'
# 应用代理路径
# APP_ROOT_PATH = '/rule'
APP_ROOT_PATH = ''
# 应用主机
APP_HOST = '0.0.0.0'
# 应用端口
APP_PORT = 9099
# 应用版本
APP_VERSION= '1.0.0'
# 应用是否开启热重载
APP_RELOAD = true
# 应用是否开启IP归属区域查询
APP_IP_LOCATION_QUERY = true
# 应用是否允许账号同时登录
APP_SAME_TIME_LOGIN = true
# 是否允许接口直接执行sql进行测试
APP_TSET_SQL = false
# -------- mysql数据库配置 --------
# 数据库主机
MYSQL_DB_HOST = 'ngsk.tech'
# 数据库端口
MYSQL_DB_PORT = 33306
# 数据库用户名
MYSQL_DB_USERNAME = 'root'
# 数据库密码
MYSQL_DB_PASSWORD = 'ngsk0809cruise'
# 数据库名称
MYSQL_DB_DATABASE = 'data_governance'
# -------- pgvector数据库配置 --------
# 数据库主机
PG_DB_HOST = '192.168.5.30'
# 数据库端口
PG_DB_PORT = 5432
# 数据库用户名
PG_DB_USERNAME = 'myuser'
# 数据库密码
PG_DB_PASSWORD = 'mypassword'
# 数据库名称
PG_DB_DATABASE = 'vectordb'
# -------- 大语言模型配置 --------
llm_base_url = 'http://192.168.5.20:4090/v1'
llm_api_key = 'gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb'
llm_model = 'qwen3-30b-a3b-instruct-2507'
# -------- 向量模型配置 --------
emb_base_url = 'http://192.168.5.20:4090/v1'
emb_api_key = 'gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb'
emb_model = 'bge-m3'

@ -0,0 +1,87 @@
# -------- 应用配置 --------
# 应用运行环境
APP_ENV = 'prod'
# 应用名称
APP_NAME = '电网智能巡航系统'
# 应用代理路径
# APP_ROOT_PATH = '/cruise'
APP_ROOT_PATH = ''
# 应用主机
APP_HOST = '0.0.0.0'
# 应用端口
APP_PORT = 9099
# 应用版本
APP_VERSION= '1.6.1'
# 应用是否开启热重载
APP_RELOAD = true
# 应用是否开启IP归属区域查询
APP_IP_LOCATION_QUERY = true
# 应用是否允许账号同时登录
APP_SAME_TIME_LOGIN = true
# 是否允许接口直接执行sql进行测试
APP_TSET_SQL = true
# -------- Jwt配置 --------
# Jwt秘钥
JWT_SECRET_KEY = 'b01c66dc2c58dc6a0aabfe2144256be36226de378bf87f72c0c795dda67f4d55'
# Jwt算法
JWT_ALGORITHM = 'HS256'
# 令牌过期时间
JWT_EXPIRE_MINUTES = 1440
# redis中令牌过期时间
JWT_REDIS_EXPIRE_MINUTES = 300
# -------- 数据库1配置 --------
# 数据库类型,可选的有'mysql'、'postgresql',默认为'mysql'
DB_TYPE = 'mysql'
# 数据库主机
DB_HOST = '10.92.176.60'
# 数据库端口
DB_PORT = 13306
# 数据库用户名
DB_USERNAME = 'root'
# 数据库密码
DB_PASSWORD = 'root'
# 数据库名称
DB_DATABASE = 'cruise'
# 是否开启sqlalchemy日志
DB_ECHO = true
# 允许溢出连接池大小的最大连接数
DB_MAX_OVERFLOW = 10
# 连接池大小0表示连接数无限制
DB_POOL_SIZE = 50
# 连接回收时间(单位:秒)
DB_POOL_RECYCLE = 3600
# 连接池中没有线程可用时,最多等待的时间(单位:秒)
DB_POOL_TIMEOUT = 30
# -------- 数据库2配置 --------
DB2_TYPE = 'mysql'
# 数据库主机
DB2_HOST = '10.92.176.60'
# 数据库端口
DB2_PORT = 13306
# 数据库用户名
DB2_USERNAME = 'root'
# 数据库密码
DB2_PASSWORD = 'root'
# 数据库名称
DB2_DATABASE = 'operationrisk'
# -------- Redis配置 --------
# Redis主机
REDIS_HOST = '10.92.176.60'
# Redis端口
REDIS_PORT = 16379
# Redis用户名
REDIS_USERNAME = ''
# Redis密码
REDIS_PASSWORD = ''
# Redis数据库
REDIS_DATABASE = 2
AI_BASE_URL = 'http://192.168.196.140:1025/v1'
AI_API_KEY = 'ollama'
AI_MODEL_NAME = 'deepseek-r1:14b'
AI_DAILY_REPORT_URL = 'http://192.168.196.140:19013/sgd/'

@ -0,0 +1,19 @@
# 使用官方 Python 基础镜像
FROM hub.1panel.dev/library/python:3.12-slim
# 设置工作目录
WORKDIR /app
# 将当前目录下的所有文件复制到容器中的 /app 目录
COPY . /app
# 安装依赖
#RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
RUN pip install --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt
# 暴露应用程序运行的端口(如果需要)
EXPOSE 9099
# 启动命令
CMD ["python", "app.py","--env=prod"]

@ -0,0 +1,66 @@
# backend.py
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, APIRouter
from pydantic import BaseModel
import asyncio
from starlette.middleware.cors import CORSMiddleware
from rag import rag_generate_rule
@asynccontextmanager
async def lifespan(app : FastAPI):
# logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}')
print(f'开始启动')
print(f'启动成功')
yield
print(f'关闭成功')
app = FastAPI(title="RAG SQL Generator API", description="调用RAG生成SQL语句的后端服务", lifespan=lifespan)
rule_router = APIRouter(prefix="/rule")
# 添加CORS中间件 - 允许所有来源(开发环境)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 在生产环境中应该指定具体的域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class QueryRequest(BaseModel):
query: str
class QueryResponse(BaseModel):
sql: str
@rule_router.post("/generate_sql", response_model=QueryResponse)
async def generate_sql(request: QueryRequest):
"""
接收用户问题并调用RAG生成SQL语句
- **query**: 用户的自然语言问题
- 返回生成的SQL语句
"""
try:
# 调用rag_generate_rule方法生成SQL
result = await rag_generate_rule(request.query)
return QueryResponse(sql=result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}")
@rule_router.get("/health")
async def health_check():
"""
健康检查端点
"""
return {"status": "healthy"}
# 注册带前缀的路由器
app.include_router(rule_router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

129
env.py

@ -0,0 +1,129 @@
import argparse
import os
import sys
from dotenv import load_dotenv
from functools import lru_cache
from pydantic import computed_field
from pydantic_settings import BaseSettings
from typing import Literal
class GetConfig:
def __init__(self):
self.parse_cli_args()
@lru_cache()
def get_llm_config(self):
return LlmSettings()
@lru_cache()
def get_emb_config(self):
return EmbSettings()
@lru_cache()
def get_mysql_database_config(self):
"""
获取数据库配置
"""
# 实例化数据库配置模型
return MysqlDataBaseSettings()
@lru_cache()
def get_pgvector_database_config(self):
"""
获取数据库配置
"""
# 实例化数据库配置模型
return PgvectorDataBaseSettings()
@staticmethod
def parse_cli_args():
"""
解析命令行参数
"""
if 'uvicorn' in sys.argv[0]:
# 使用uvicorn启动时命令行参数需要按照uvicorn的文档进行配置无法自定义参数
pass
else:
# 使用argparse定义命令行参数
parser = argparse.ArgumentParser(description='命令行参数')
parser.add_argument('--env', type=str, default='', help='运行环境')
# 解析命令行参数
args = parser.parse_args()
# 设置环境变量如果未设置命令行参数默认APP_ENV为dev
os.environ['APP_ENV'] = args.env if args.env else 'dev'
# 读取运行环境
run_env = os.environ.get('APP_ENV', '')
# 运行环境未指定时默认加载.env.dev
env_file = '.env.dev'
# 运行环境不为空时按命令行参数加载对应.env文件
if run_env != '':
env_file = f'.env.{run_env}'
# 加载配置
load_dotenv(env_file)
print(f'当前运行环境为:{run_env}')
print(os.getenv("DB_HOST"))
# 实例化获取配置类
get_config = GetConfig()
class LlmSettings:
"""
AI配置
"""
base_url: str = os.getenv('llm_base_url', '')
api_key: str = os.getenv('llm_api_key', '')
model_name: str = os.getenv('llm_model', '')
class EmbSettings:
"""
AI配置
"""
base_url: str = os.getenv('emb_base_url', '')
api_key: str = os.getenv('emb_api_key', '')
model_name: str = os.getenv('emb_model', '')
class MysqlDataBaseSettings:
"""
数据库配置
"""
# 数据库配置
db_host: str = os.getenv('MYSQL_DB_HOST', '')
db_port: int = int(os.getenv('MYSQL_DB_PORT', 33306))
db_username: str = os.getenv('MYSQL_DB_USERNAME', '')
db_password: str = os.getenv('MYSQL_DB_PASSWORD', '')
db_database: str = os.getenv('MYSQL_DB_DATABASE', '')
db_echo: bool = True
db_max_overflow: int = 10
db_pool_size: int = 50
db_pool_recycle: int = 3600
db_pool_timeout: int = 30
class PgvectorDataBaseSettings:
"""
数据库配置
"""
# 数据库配置
db_host: str = os.getenv('PG_DB_HOST', '')
db_port: int = os.getenv('PG_DB_PORT', 5432)
db_username: str = os.getenv('PG_DB_USERNAME', '')
db_password: str = os.getenv('PG_DB_PASSWORD', '')
db_database: str = os.getenv('PG_DB_DATABASE', '')
db_echo: bool = True
db_max_overflow: int = 10
db_pool_size: int = 50
db_pool_recycle: int = 3600
db_pool_timeout: int = 30
# mysql数据库配置
MysqlDataBaseConfig = get_config.get_mysql_database_config()
# pgvector数据库配置
PgvectorDataBaseConfig = get_config.get_pgvector_database_config()
LlmBaseConfig = get_config.get_llm_config()
EmbBaseConfig = get_config.get_emb_config()

@ -0,0 +1,224 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>SQL生成聊天界面</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background-color: #f5f5f5;
}
.chat-container {
background-color: white;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
padding: 20px;
margin-bottom: 20px;
height: 500px;
overflow-y: auto;
}
.message {
margin-bottom: 15px;
padding: 10px;
border-radius: 8px;
max-width: 80%;
}
.user-message {
background-color: #dcf8c6;
margin-left: auto;
}
.bot-message {
background-color: #e5e5ea;
margin-right: auto;
}
.input-container {
display: flex;
gap: 10px;
}
#user-input {
flex: 1;
padding: 12px;
border: 1px solid #ddd;
border-radius: 20px;
outline: none;
}
#send-button {
padding: 12px 24px;
background-color: #4CAF50;
color: white;
border: none;
border-radius: 20px;
cursor: pointer;
}
#send-button:disabled {
background-color: #cccccc;
cursor: not-allowed;
}
.loading {
text-align: center;
color: #666;
font-style: italic;
}
.sql-result {
font-family: monospace;
background-color: #f8f8f8;
padding: 10px;
border-left: 4px solid #4CAF50;
white-space: pre-wrap;
overflow-x: auto;
}
.error-message {
color: #f44336;
background-color: #ffebee;
padding: 10px;
border-radius: 5px;
}
</style>
</head>
<body>
<h1>SQL生成聊天助手</h1>
<div class="chat-container" id="chat-container">
<div class="message bot-message">
您好请描述您需要生成SQL的需求我会为您生成相应的SQL语句。
</div>
</div>
<div class="input-container">
<input type="text" id="user-input" placeholder="请输入您的问题..." />
<button id="send-button">发送</button>
</div>
<script>
const chatContainer = document.getElementById('chat-container');
const userInput = document.getElementById('user-input');
const sendButton = document.getElementById('send-button');
// 添加消息到聊天界面
function addMessage(text, isUser = false, isSQL = false, isError = false) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${isUser ? 'user-message' : 'bot-message'}`;
if (isSQL) {
const sqlDiv = document.createElement('div');
sqlDiv.className = 'sql-result';
sqlDiv.textContent = text;
messageDiv.appendChild(sqlDiv);
} else if (isError) {
messageDiv.className = 'error-message';
messageDiv.textContent = text;
} else {
messageDiv.textContent = text;
}
chatContainer.appendChild(messageDiv);
chatContainer.scrollTop = chatContainer.scrollHeight;
}
// 显示加载状态
function showLoading() {
const loadingDiv = document.createElement('div');
loadingDiv.className = 'message bot-message loading';
loadingDiv.id = 'loading-message';
loadingDiv.textContent = '正在生成SQL请稍候...';
chatContainer.appendChild(loadingDiv);
chatContainer.scrollTop = chatContainer.scrollHeight;
// 禁用输入和发送按钮
userInput.disabled = true;
sendButton.disabled = true;
sendButton.textContent = '处理中...';
}
// 隐藏加载状态
function hideLoading() {
const loadingMessage = document.getElementById('loading-message');
if (loadingMessage) {
loadingMessage.remove();
}
// 启用输入和发送按钮
userInput.disabled = false;
sendButton.disabled = false;
sendButton.textContent = '发送';
}
// 调用后端API生成SQL
async function generateSQL(query) {
try {
const response = await fetch('http://localhost:8000/rule/generate_sql', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ query: query })
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail || '请求失败');
}
const data = await response.json();
return data.sql;
} catch (error) {
throw new Error(`API调用失败: ${error.message}`);
}
}
// 处理用户发送消息
async function handleSendMessage() {
const query = userInput.value.trim();
if (!query) return;
// 添加用户消息
addMessage(query, true);
userInput.value = '';
// 显示加载状态
showLoading();
try {
// 调用后端生成SQL
const sqlResult = await generateSQL(query);
// 隐藏加载状态
hideLoading();
// 显示SQL结果
addMessage(sqlResult, false, true);
} catch (error) {
// 隐藏加载状态
hideLoading();
// 显示错误信息
addMessage(`错误: ${error.message}`, false, false, true);
}
}
// 绑定发送按钮点击事件
sendButton.addEventListener('click', handleSendMessage);
// 绑定回车键发送事件
userInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter') {
handleSendMessage();
}
});
</script>
</body>
</html>

@ -0,0 +1,9 @@
aiomysql==0.2.0
fastapi==0.116.1
numpy==2.3.2
openai==1.99.6
pandas==2.3.1
psycopg2_binary==2.9.10
pydantic==2.11.7
starlette==0.47.2
uvicorn==0.35.0

@ -2,6 +2,9 @@ import asyncio
import aiomysql
from typing import List, Dict, Any, Optional
from env import MysqlDataBaseConfig
class AsyncMySQLClient:
def __init__(self, host: str, port: int, user: str, password: str, db: str):
"""
@ -142,6 +145,8 @@ async def search_desc_by_table_names(table_name_list : list, db_client: AsyncMyS
})
return ans
except Exception as e:
print(f"查询失败: {e}")
finally:
# 关闭连接
await db_client.close()
@ -186,11 +191,11 @@ async def insert(data : list[dict[str, Any]], db_client: AsyncMySQLClient):
# 创建数据库客户端实例
db_client = AsyncMySQLClient(
host='ngsk.tech',
port=33306,
user='root',
password='ngsk0809cruise',
db='data_governance'
host=MysqlDataBaseConfig.db_host,
port=MysqlDataBaseConfig.db_port,
user=MysqlDataBaseConfig.db_username,
password=MysqlDataBaseConfig.db_password,
db=MysqlDataBaseConfig.db_database
)
async def get_db():
# 创建数据库客户端实例

@ -1,14 +1,20 @@
from openai import AsyncOpenAI
client = AsyncOpenAI(
api_key="gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb",
base_url="http://192.168.5.20:4090/v1",
from env import LlmBaseConfig, EmbBaseConfig
llm_client = AsyncOpenAI(
api_key=LlmBaseConfig.api_key,
base_url=LlmBaseConfig.base_url,
)
emb_client = AsyncOpenAI(
api_key=EmbBaseConfig.api_key,
base_url=EmbBaseConfig.base_url,
)
async def generation_rule(prompt):
response = await client.chat.completions.create(
model="qwen3-30b-a3b-instruct-2507",
response = await llm_client.chat.completions.create(
model=LlmBaseConfig.model_name,
messages=prompt,
n = 1,
stream = False,
@ -22,8 +28,8 @@ async def generation_rule(prompt):
return response.choices[0].message.content
async def generation_vector(text):
response = await client.embeddings.create(
model="bge-m3", # 替换为实际的向量模型名称
response = await emb_client.embeddings.create(
model=EmbBaseConfig.model_name, # 替换为实际的向量模型名称
input=text,
encoding_format="float"
)
@ -52,8 +58,8 @@ async def rerank_documents(query, documents, top_n=None):
}
]
response = await client.chat.completions.create(
model="qwen3-30b-a3b-instruct-2507", # 使用已知可用的模型
response = await llm_client.chat.completions.create(
model=LlmBaseConfig.model_name, # 使用已知可用的模型
messages=messages,
temperature=0.0,
max_tokens=100

@ -2,13 +2,15 @@ import psycopg2
from psycopg2.extras import execute_values
import numpy as np
from env import PgvectorDataBaseConfig
# 数据库连接配置
DB_CONFIG = {
"host": "192.168.5.30",
"database": "vectordb",
"user": "myuser",
"password": "mypassword",
"port": 5432
"host": PgvectorDataBaseConfig.db_host,
"database": PgvectorDataBaseConfig.db_database,
"user": PgvectorDataBaseConfig.db_username,
"password": PgvectorDataBaseConfig.db_password,
"port": PgvectorDataBaseConfig.db_port
}

Loading…
Cancel
Save