From e4e938d73ba070b51d5fc5a24999219b14e4bae4 Mon Sep 17 00:00:00 2001 From: lijiazheng Date: Mon, 11 Aug 2025 15:32:47 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20RAG=20SQL=E7=94=9F?= =?UTF-8?q?=E6=88=90=E7=B3=BB=E7=BB=9F=E7=9A=84=E5=90=8E=E7=AB=AF=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E5=92=8C=E5=89=8D=E7=AB=AF=E9=A1=B5=E9=9D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增后端 API,使用 FastAPI 框架实现 - 创建前端聊天界面,使用 HTML 和 JavaScript 实现 - 添加环境变量配置文件,支持不同环境的配置 - 实现与 MySQL 和 pgvector 数据库的连接和查询 - 集成大语言模型和向量模型,用于生成 SQL 语句 --- .env.dev | 58 +++++++++++ .env.prod | 87 +++++++++++++++++ Dockerfile | 19 ++++ backend.py | 66 +++++++++++++ env.py | 129 +++++++++++++++++++++++++ index.html | 224 +++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 9 ++ util/use_mysql.py | 15 ++- util/use_opanai.py | 24 +++-- util/use_pgvector.py | 12 ++- 10 files changed, 624 insertions(+), 19 deletions(-) create mode 100644 .env.dev create mode 100644 .env.prod create mode 100644 Dockerfile create mode 100644 backend.py create mode 100644 env.py create mode 100644 index.html create mode 100644 requirements.txt diff --git a/.env.dev b/.env.dev new file mode 100644 index 0000000..24f5ded --- /dev/null +++ b/.env.dev @@ -0,0 +1,58 @@ +# -------- 应用配置 -------- +# 应用运行环境 +APP_ENV = 'dev' +# 应用名称 +APP_NAME = '规则生成系统' +# 应用代理路径 +# APP_ROOT_PATH = '/rule' +APP_ROOT_PATH = '' +# 应用主机 +APP_HOST = '0.0.0.0' +# 应用端口 +APP_PORT = 9099 +# 应用版本 +APP_VERSION= '1.0.0' +# 应用是否开启热重载 +APP_RELOAD = true +# 应用是否开启IP归属区域查询 +APP_IP_LOCATION_QUERY = true +# 应用是否允许账号同时登录 +APP_SAME_TIME_LOGIN = true +# 是否允许接口直接执行sql进行测试 +APP_TSET_SQL = false + +# -------- mysql数据库配置 -------- +# 数据库主机 +MYSQL_DB_HOST = 'ngsk.tech' +# 数据库端口 +MYSQL_DB_PORT = 33306 +# 数据库用户名 +MYSQL_DB_USERNAME = 'root' +# 数据库密码 +MYSQL_DB_PASSWORD = 'ngsk0809cruise' +# 数据库名称 +MYSQL_DB_DATABASE = 'data_governance' + + +# -------- pgvector数据库配置 -------- +# 数据库主机 +PG_DB_HOST = '192.168.5.30' +# 
数据库端口 +PG_DB_PORT = 5432 +# 数据库用户名 +PG_DB_USERNAME = 'myuser' +# 数据库密码 +PG_DB_PASSWORD = 'mypassword' +# 数据库名称 +PG_DB_DATABASE = 'vectordb' + +# -------- 大语言模型配置 -------- +llm_base_url = 'http://192.168.5.20:4090/v1' +llm_api_key = 'gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb' +llm_model = 'qwen3-30b-a3b-instruct-2507' + +# -------- 向量模型配置 -------- +emb_base_url = 'http://192.168.5.20:4090/v1' +emb_api_key = 'gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb' +emb_model = 'bge-m3' + diff --git a/.env.prod b/.env.prod new file mode 100644 index 0000000..5aa584e --- /dev/null +++ b/.env.prod @@ -0,0 +1,87 @@ +# -------- 应用配置 -------- +# 应用运行环境 +APP_ENV = 'prod' +# 应用名称 +APP_NAME = '电网智能巡航系统' +# 应用代理路径 +# APP_ROOT_PATH = '/cruise' +APP_ROOT_PATH = '' +# 应用主机 +APP_HOST = '0.0.0.0' +# 应用端口 +APP_PORT = 9099 +# 应用版本 +APP_VERSION= '1.6.1' +# 应用是否开启热重载 +APP_RELOAD = true +# 应用是否开启IP归属区域查询 +APP_IP_LOCATION_QUERY = true +# 应用是否允许账号同时登录 +APP_SAME_TIME_LOGIN = true +# 是否允许接口直接执行sql进行测试 +APP_TSET_SQL = true + +# -------- Jwt配置 -------- +# Jwt秘钥 +JWT_SECRET_KEY = 'b01c66dc2c58dc6a0aabfe2144256be36226de378bf87f72c0c795dda67f4d55' +# Jwt算法 +JWT_ALGORITHM = 'HS256' +# 令牌过期时间 +JWT_EXPIRE_MINUTES = 1440 +# redis中令牌过期时间 +JWT_REDIS_EXPIRE_MINUTES = 300 + + +# -------- 数据库1配置 -------- +# 数据库类型,可选的有'mysql'、'postgresql',默认为'mysql' +DB_TYPE = 'mysql' +# 数据库主机 +DB_HOST = '10.92.176.60' +# 数据库端口 +DB_PORT = 13306 +# 数据库用户名 +DB_USERNAME = 'root' +# 数据库密码 +DB_PASSWORD = 'root' +# 数据库名称 +DB_DATABASE = 'cruise' +# 是否开启sqlalchemy日志 +DB_ECHO = true +# 允许溢出连接池大小的最大连接数 +DB_MAX_OVERFLOW = 10 +# 连接池大小,0表示连接数无限制 +DB_POOL_SIZE = 50 +# 连接回收时间(单位:秒) +DB_POOL_RECYCLE = 3600 +# 连接池中没有线程可用时,最多等待的时间(单位:秒) +DB_POOL_TIMEOUT = 30 + +# -------- 数据库2配置 -------- +DB2_TYPE = 'mysql' +# 数据库主机 +DB2_HOST = '10.92.176.60' +# 数据库端口 +DB2_PORT = 13306 +# 数据库用户名 +DB2_USERNAME = 'root' +# 数据库密码 +DB2_PASSWORD = 'root' +# 数据库名称 +DB2_DATABASE = 'operationrisk' + +# -------- Redis配置 -------- +# Redis主机 
# backend.py
"""HTTP backend exposing the RAG SQL generator through FastAPI."""
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, APIRouter
from pydantic import BaseModel
import asyncio

from starlette.middleware.cors import CORSMiddleware

from rag import rag_generate_rule


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler passed to the ``FastAPI`` constructor below.

    Prints two start-up messages, yields control, and prints a shut-down
    message once control comes back.
    """
    print('开始启动')
    print('启动成功')
    yield
    print('关闭成功')


app = FastAPI(
    title="RAG SQL Generator API",
    description="调用RAG生成SQL语句的后端服务",
    lifespan=lifespan,
)

rule_router = APIRouter(prefix="/rule")

# CORS: every origin is accepted (development setup), so credentials are
# turned off. Combining a wildcard origin with allow_credentials=True lets
# any website issue credentialed requests against this API. To support
# cookies/authorization headers, list the trusted origins explicitly in
# allow_origins and only then switch allow_credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


class QueryRequest(BaseModel):
    """Request body of ``POST /rule/generate_sql``."""

    # The user's natural-language question.
    query: str


class QueryResponse(BaseModel):
    """Response body of ``POST /rule/generate_sql``."""

    # The generated SQL statement.
    sql: str


@rule_router.post("/generate_sql", response_model=QueryResponse)
async def generate_sql(request: QueryRequest):
    """
    Take the user's question and call the RAG pipeline to generate SQL.

    - **query**: the user's natural-language question
    - Returns the generated SQL statement

    Any exception raised while generating or wrapping the result is turned
    into an HTTP 500 response whose detail carries the error text.
    """
    try:
        result = await rag_generate_rule(request.query)
        return QueryResponse(sql=result)
    except Exception as e:
        # Chain the original exception so the real traceback is kept in logs.
        raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}") from e


@rule_router.get("/health")
async def health_check():
    """
    Health-check endpoint; always reports a healthy status.
    """
    return {"status": "healthy"}


# Register the router that carries the "/rule" prefix.
app.include_router(rule_router)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
"""Environment-driven configuration: picks a ``.env.<name>`` file and exposes settings."""
import argparse
import os
import sys
from dotenv import load_dotenv
from functools import lru_cache
from pydantic import computed_field
from pydantic_settings import BaseSettings
from typing import Literal


class GetConfig:
    """Loads the ``.env`` file for the active environment and hands out settings.

    Instantiating the class immediately runs :meth:`parse_cli_args`, which
    loads the dotenv file into ``os.environ``.
    """

    def __init__(self):
        self.parse_cli_args()

    # NOTE: lru_cache on a method keys on ``self``; this module creates a
    # single GetConfig instance, so each getter builds its object once.
    @lru_cache()
    def get_llm_config(self):
        """Return the (cached) large-language-model settings object."""
        return LlmSettings()

    @lru_cache()
    def get_emb_config(self):
        """Return the (cached) embedding-model settings object."""
        return EmbSettings()

    @lru_cache()
    def get_mysql_database_config(self):
        """Return the (cached) MySQL database settings object."""
        return MysqlDataBaseSettings()

    @lru_cache()
    def get_pgvector_database_config(self):
        """Return the (cached) pgvector database settings object."""
        return PgvectorDataBaseSettings()

    @staticmethod
    def parse_cli_args():
        """Resolve the run environment and load the matching ``.env.<env>`` file.

        Unless the process was started through uvicorn (which owns its own
        command line), an optional ``--env`` flag is read. Precedence for
        ``APP_ENV`` is: ``--env`` flag, then an already exported ``APP_ENV``,
        then ``'dev'``. The chosen file is loaded with ``load_dotenv``.
        """
        if 'uvicorn' not in sys.argv[0]:
            parser = argparse.ArgumentParser(description='命令行参数')
            parser.add_argument('--env', type=str, default='', help='运行环境')
            # This runs at import time, so tolerate flags owned by other
            # launchers (test runners, process managers); parse_args() would
            # terminate the process on the first unrecognised argument.
            args, _unknown = parser.parse_known_args()
            # Do not clobber an APP_ENV exported by the caller when no flag
            # is given; fall back to 'dev' only when nothing is set.
            os.environ['APP_ENV'] = args.env or os.environ.get('APP_ENV') or 'dev'
        run_env = os.environ.get('APP_ENV', '')
        # With no environment specified, .env.dev is loaded.
        env_file = f'.env.{run_env}' if run_env != '' else '.env.dev'
        load_dotenv(env_file)
        print(f'当前运行环境为:{run_env}')


# Must be created BEFORE the settings classes below are defined: their class
# bodies call os.getenv() at definition time, after load_dotenv() has run.
get_config = GetConfig()


class LlmSettings:
    """Large-language-model endpoint settings read from ``llm_*`` variables."""

    base_url: str = os.getenv('llm_base_url', '')
    api_key: str = os.getenv('llm_api_key', '')
    model_name: str = os.getenv('llm_model', '')


class EmbSettings:
    """Embedding-model endpoint settings read from ``emb_*`` variables."""

    base_url: str = os.getenv('emb_base_url', '')
    api_key: str = os.getenv('emb_api_key', '')
    model_name: str = os.getenv('emb_model', '')


class MysqlDataBaseSettings:
    """MySQL connection settings read from ``MYSQL_DB_*`` variables."""

    db_host: str = os.getenv('MYSQL_DB_HOST', '')
    db_port: int = int(os.getenv('MYSQL_DB_PORT', 33306))
    db_username: str = os.getenv('MYSQL_DB_USERNAME', '')
    db_password: str = os.getenv('MYSQL_DB_PASSWORD', '')
    db_database: str = os.getenv('MYSQL_DB_DATABASE', '')

    # Fixed values, not read from the environment.
    db_echo: bool = True
    db_max_overflow: int = 10
    db_pool_size: int = 50
    db_pool_recycle: int = 3600
    db_pool_timeout: int = 30


class PgvectorDataBaseSettings:
    """pgvector (PostgreSQL) connection settings read from ``PG_DB_*`` variables."""

    db_host: str = os.getenv('PG_DB_HOST', '')
    # os.getenv returns a str when the variable is set; convert so the value
    # matches the annotation and the MySQL settings above.
    db_port: int = int(os.getenv('PG_DB_PORT', 5432))
    db_username: str = os.getenv('PG_DB_USERNAME', '')
    db_password: str = os.getenv('PG_DB_PASSWORD', '')
    db_database: str = os.getenv('PG_DB_DATABASE', '')

    # Fixed values, not read from the environment.
    db_echo: bool = True
    db_max_overflow: int = 10
    db_pool_size: int = 50
    db_pool_recycle: int = 3600
    db_pool_timeout: int = 30


# MySQL database configuration
MysqlDataBaseConfig = get_config.get_mysql_database_config()
# pgvector database configuration
PgvectorDataBaseConfig = get_config.get_pgvector_database_config()

LlmBaseConfig = get_config.get_llm_config()
EmbBaseConfig = get_config.get_emb_config()

SQL生成聊天助手

+
+
+ 您好!请描述您需要生成SQL的需求,我会为您生成相应的SQL语句。 +
+
+ +
+ + +
+ + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1032ed5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +aiomysql==0.2.0 +fastapi==0.116.1 +numpy==2.3.2 +openai==1.99.6 +pandas==2.3.1 +psycopg2_binary==2.9.10 +pydantic==2.11.7 +starlette==0.47.2 +uvicorn==0.35.0 diff --git a/util/use_mysql.py b/util/use_mysql.py index f95fdf0..094fd6e 100644 --- a/util/use_mysql.py +++ b/util/use_mysql.py @@ -2,6 +2,9 @@ import asyncio import aiomysql from typing import List, Dict, Any, Optional +from env import MysqlDataBaseConfig + + class AsyncMySQLClient: def __init__(self, host: str, port: int, user: str, password: str, db: str): """ @@ -142,6 +145,8 @@ async def search_desc_by_table_names(table_name_list : list, db_client: AsyncMyS }) return ans + except Exception as e: + print(f"查询失败: {e}") finally: # 关闭连接 await db_client.close() @@ -186,11 +191,11 @@ async def insert(data : list[dict[str, Any]], db_client: AsyncMySQLClient): # 创建数据库客户端实例 db_client = AsyncMySQLClient( - host='ngsk.tech', - port=33306, - user='root', - password='ngsk0809cruise', - db='data_governance' + host=MysqlDataBaseConfig.db_host, + port=MysqlDataBaseConfig.db_port, + user=MysqlDataBaseConfig.db_username, + password=MysqlDataBaseConfig.db_password, + db=MysqlDataBaseConfig.db_database ) async def get_db(): # 创建数据库客户端实例 diff --git a/util/use_opanai.py b/util/use_opanai.py index a3559da..4a5b445 100644 --- a/util/use_opanai.py +++ b/util/use_opanai.py @@ -1,14 +1,20 @@ from openai import AsyncOpenAI -client = AsyncOpenAI( - api_key="gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb", - base_url="http://192.168.5.20:4090/v1", +from env import LlmBaseConfig, EmbBaseConfig + +llm_client = AsyncOpenAI( + api_key=LlmBaseConfig.api_key, + base_url=LlmBaseConfig.base_url, ) +emb_client = AsyncOpenAI( + api_key=EmbBaseConfig.api_key, + base_url=EmbBaseConfig.base_url, +) async def generation_rule(prompt): - response = await client.chat.completions.create( 
- model="qwen3-30b-a3b-instruct-2507", + response = await llm_client.chat.completions.create( + model=LlmBaseConfig.model_name, messages=prompt, n = 1, stream = False, @@ -22,8 +28,8 @@ async def generation_rule(prompt): return response.choices[0].message.content async def generation_vector(text): - response = await client.embeddings.create( - model="bge-m3", # 替换为实际的向量模型名称 + response = await emb_client.embeddings.create( + model=EmbBaseConfig.model_name, # 替换为实际的向量模型名称 input=text, encoding_format="float" ) @@ -52,8 +58,8 @@ async def rerank_documents(query, documents, top_n=None): } ] - response = await client.chat.completions.create( - model="qwen3-30b-a3b-instruct-2507", # 使用已知可用的模型 + response = await llm_client.chat.completions.create( + model=LlmBaseConfig.model_name, # 使用已知可用的模型 messages=messages, temperature=0.0, max_tokens=100 diff --git a/util/use_pgvector.py b/util/use_pgvector.py index f48688e..836a418 100644 --- a/util/use_pgvector.py +++ b/util/use_pgvector.py @@ -2,13 +2,15 @@ import psycopg2 from psycopg2.extras import execute_values import numpy as np +from env import PgvectorDataBaseConfig + # 数据库连接配置 DB_CONFIG = { - "host": "192.168.5.30", - "database": "vectordb", - "user": "myuser", - "password": "mypassword", - "port": 5432 + "host": PgvectorDataBaseConfig.db_host, + "database": PgvectorDataBaseConfig.db_database, + "user": PgvectorDataBaseConfig.db_username, + "password": PgvectorDataBaseConfig.db_password, + "port": PgvectorDataBaseConfig.db_port }