feat: RAG solution code implementation

main
lijiazheng 6 months ago
commit 1025da7b26

@ -0,0 +1,99 @@
import pandas as pd
from typing import Dict, Any, List
import asyncio
from util.random_string import generate_random_string
from util.use_mysql import AsyncMySQLClient
async def read_excel_to_mysql(excel_file: str, db_client: AsyncMySQLClient, table_name: str):
"""
读取Excel每个sheet的每一行并将每一行存到mysql
Args:
excel_file (str): Excel文件路径
db_client (AsyncMySQLClient): 数据库客户端
table_name (str): 目标数据库表名
"""
    try:
        # Connect to the database
        await db_client.connect()
        # Open the Excel workbook and enumerate its sheets
        excel_file_obj = pd.ExcelFile(excel_file)
        success_count = 0
        fail_details = []
        # Iterate over every sheet
        for sheet_name in excel_file_obj.sheet_names:
            # Load the sheet into a DataFrame
            df = pd.read_excel(excel_file, sheet_name=sheet_name)
            # Iterate over every row
            for index, row in df.iterrows():
                try:
                    # Convert the row to a dict keyed by column header
                    row_data: dict = row.to_dict()
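                    # The Chinese keys below must match the Excel column headers exactly; a random
                    # 32-character id is generated as the primary key. Note the INSERT targets
                    # business_table_schema directly, so the table_name argument is not used here.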
sql = ("INSERT INTO business_table_schema (id, table_name_eng, table_name_zh, field_eng, field_zh, "
"integrity, consistency, timeliness, accuracy, standardization, type) "
"VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)")
params = (generate_random_string(32), row_data['表名'], row_data['表中文名'], row_data['字段名'],
row_data['字段中文名'], row_data['完整性'], row_data['一致性'], row_data['及时性'], row_data['准确性'],
row_data['规范性'], row_data['业务数据类型'])
                    # Run the insert
                    await db_client.execute(sql, params)
                    success_count += 1
                except Exception as e:
                    fail_details.append({
                        'sheet': sheet_name,
                        'row': index + 1,
                        'error': str(e)
                    })
                    print(f"Error while processing sheet '{sheet_name}', row {index + 1}: {e}")
        result_msg = f"Inserted {success_count} rows successfully"
        if fail_details:
            result_msg += f", {len(fail_details)} rows failed"
return {
'message': result_msg,
'success_count': success_count,
'fail_details': fail_details
}
    finally:
        # Close the database connection
        await db_client.close()

# Usage example
async def main():
    # Database configuration
db_client = AsyncMySQLClient(
host='ngsk.tech',
port=33306,
user='root',
password='ngsk0809cruise',
db='data_governance'
)
    # Path to the Excel file
    excel_file = "电网管理平台(规建域)数据质量标准导出.xlsx"
    table_name = "business_table_schema"
    # Read the Excel data and insert it into the database
    result = await read_excel_to_mysql(excel_file, db_client, table_name)
    print(result['message'])
    if result['fail_details']:
        print("Failure details:")
        for fail in result['fail_details']:
            print(f"  Sheet: {fail['sheet']}, row: {fail['row']}, error: {fail['error']}")

# Run directly if needed
if __name__ == "__main__":
    asyncio.run(main())

@ -0,0 +1,60 @@
from util.use_pgvector import connect_to_db, insert_vectors, setup_vector_extension
import json
from util.use_opanai import generation_vector
def txt_to_json_objects(file_path):
"""
读取txt文件的每一行转换成json对象
"""
json_objects = []
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
line = line.strip() # 去除行首尾空白字符
if line: # 跳过空行
try:
# 假设每行本身就是一个有效的JSON字符串
json_obj = json.loads(line)
json_objects.append(json_obj)
except json.JSONDecodeError:
# 如果不是JSON格式可以自定义处理方式
print(f"无法解析行: {line}")
return json_objects
async def main():
    # Connect to the database
    conn = connect_to_db()
    # Set up the pgvector extension and tables
    setup_vector_extension(conn)
    json_objects = txt_to_json_objects("规建域识别可能涉及的表名及说明1.txt")
    sample_data = []
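    # Build (table_name_eng, description, embedding) tuples in the order expected by insert_vectors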
for json_obj in json_objects:
content = json.dumps(json_obj)
emb = await generation_vector(content)
sample_data.append((json_obj["表名"], content, emb))
    # Insert the data
    insert_vectors(conn, sample_data)
    # Search for similar vectors
    # query_vector = [1.0, 1.0, 1.0]
    # similar_docs = search_similar_vectors(conn, query_vector)
    #
    # print("\nSimilar document search results:")
    # for content, distance in similar_docs:
    #     print(f"Content: {content}, distance: {distance}")
    # Close the connection
    conn.close()
# Run the async entry point
import asyncio

if __name__ == "__main__":
    asyncio.run(main())

@ -0,0 +1,55 @@
from util import use_pgvector, use_opanai, use_mysql
from util.use_mysql import search_desc_by_table_names
async def rag_generate_rule(query : str):
    if not query:
        return "Please enter a question"
    # Connect to the vector database
    pgvector_conn = use_pgvector.connect_to_db()
    # Embed the question
    query_emb = await use_opanai.generation_vector(query)
    # Retrieve candidate tables related to the question
    similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4)
    rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2)
    # The reranker returns a string like "[2, 1]"; map those 1-based indices back to table names
    table_name_list = [similar_docs[int(index.strip()) - 1][0] for index in rerank_docs.strip('[]').split(',')]
    print(f"[table_name_list]: {table_name_list}")
    # Fetch the schema of the related tables
    db_client = await use_mysql.get_db()
    schema = await search_desc_by_table_names(table_name_list, db_client)
    print(f"[schema]: {schema}")
    # Retrieve similar cases for the question
    similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3)
    print(f"[similar_case]: {similar_case}")
    # Ask the LLM to generate the SQL
    prompt = [
        {"role": "system", "content": f"""
你是精通postgresql的专家,十分擅长利用用户的问题和库表schema生成合适的SQL,最终结果里只需要返回SQL,不要解释。
请根据用户问题生成SQL,请勿返回其他内容。
schema: {schema}
参考案例: {similar_case}
"""},
{"role": "user", "content": query}
]
print(f"[prompt]: {prompt}")
ans = await use_opanai.generation_rule(prompt)
return ans
async def main():
ans = await rag_generate_rule("“投资规模结构分解批次备份”表BAK_PC_IP_SCALE_STRUCT_RESOLVE中“业务板块类型”不合规的记录")
print(f"[answer]: {ans}")
# Run the async entry point
import asyncio

if __name__ == "__main__":
    asyncio.run(main())

@ -0,0 +1,29 @@
import random
import string
def generate_random_string(length=16):
"""
生成指定长度的随机字符串不包含短横线(-)
参数:
length: 字符串长度默认16
返回:
随机字符串
"""
# 定义可用字符集:字母(大小写) + 数字
characters = string.ascii_letters + string.digits
# 随机选择字符并拼接
random_str = ''.join(random.choice(characters) for _ in range(length))
return random_str
# Example usage
if __name__ == "__main__":
    # 16-character random string
    print(generate_random_string())
    # 32-character random string
    print(generate_random_string(32))
    # 8-character random string
    print(generate_random_string(8))

@ -0,0 +1,12 @@
from openai import OpenAI

# Local model configuration (the OpenAI client takes the key and base URL;
# model and temperature are passed per request, not to the constructor)
llm = OpenAI(
    api_key="gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb",  # local deployments usually accept any key
    base_url="http://192.168.5.20:4090/v1",  # local model service address
)

# Call the model
response = llm.chat.completions.create(
    model="qwen3-30b-a3b-instruct-2507",  # local model name
    messages=[{"role": "user", "content": "你好,介绍一下你自己"}],
    temperature=0,
)
print(response.choices[0].message.content)

@ -0,0 +1,219 @@
import asyncio
import aiomysql
from typing import List, Dict, Any, Optional
class AsyncMySQLClient:
def __init__(self, host: str, port: int, user: str, password: str, db: str):
"""
初始化异步MySQL客户端
Args:
host: MySQL服务器地址
port: MySQL服务器端口
user: 用户名
password: 密码
db: 数据库名
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.db = db
self.pool = None
async def connect(self):
"""
建立连接池
"""
self.pool = await aiomysql.create_pool(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
db=self.db,
charset='utf8mb4',
autocommit=True
)
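        # autocommit=True commits each statement as it executes, so no explicit commit calls are needed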
async def close(self):
"""
关闭连接池
"""
if self.pool:
self.pool.close()
await self.pool.wait_closed()
async def query_one(self, sql: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
"""
查询单条记录
Args:
sql: SQL查询语句
params: 查询参数
Returns:
查询结果字典或None
"""
async with self.pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(sql, params)
result = await cursor.fetchone()
return result
async def query_all(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
"""
查询多条记录
Args:
sql: SQL查询语句
params: 查询参数
Returns:
查询结果列表
"""
async with self.pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(sql, params)
results = await cursor.fetchall()
return results
async def execute(self, sql: str, params: Optional[tuple] = None) -> int:
"""
执行SQL语句INSERT, UPDATE, DELETE等
Args:
sql: SQL执行语句
params: 执行参数
Returns:
受影响的行数
"""
async with self.pool.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute(sql, params)
return cursor.rowcount
async def search_desc_by_table_names(table_name_list : list, db_client: AsyncMySQLClient):
"""
搜索数据库中的表数据
参数:
table_name_list (list): 表名列表
db_client (AsyncMySQLClient): 异步MySQL数据库客户端用于执行数据库操作
返回值:
返回所有表的描述
"""
ans = []
try:
# 连接数据库
await db_client.connect()
for table_name in table_name_list:
# 查询多条记录
records = await db_client.query_all("SELECT * FROM business_table_schema WHERE table_name_eng like %s",
(table_name,))
content = ""
table_name_zh = ""
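            # Concatenate each field's name, type and data-quality dimensions into one description string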
for record in records:
content += record.get("field_eng")
content += "-"
content += record.get("field_zh")
content += "-"
content += record.get("type")
content += "-"
content += "完整性:"
content += record.get("integrity")
content += "-"
content += "一致性:"
content += record.get("consistency")
content += "-"
content += "及时性:"
content += record.get("timeliness")
content += "-"
content += "准确性:"
content += record.get("accuracy")
content += "-"
content += "规范性:"
content += record.get("standardization")
table_name_zh = record.get("table_name_zh")
ans.append({
"表英文名": table_name,
"表中文名": table_name_zh,
"字段介绍": content,
})
return ans
    finally:
        # Close the connection
        await db_client.close()
async def insert(data : list[dict[str, Any]], db_client: AsyncMySQLClient):
"""
插入或更新数据
参数:
data (dict): 要插入或更新的数据
db_client (AsyncMySQLClient): 异步MySQL数据库客户端用于执行数据库操作
返回值:
返回受影响的行数
"""
try:
# 连接数据库
await db_client.connect()
i, success_count = 0, 0
fail_index = []
for data in data:
if "table_name_eng" not in data or "table_name_zh" not in data or "region" not in data or "description" not in data:
print(f"{i}行数据不完整")
fail_index.append(i)
i += 1
continue
try:
affected_rows = await db_client.execute(
"INSERT INTO business_database_schema (table_name_eng, table_name_zh, region, description) VALUES (%s, %s, %s, %s)",
(data["table_name_eng"], data["table_name_zh"], data["region"], data["description"])
)
except Exception as e:
print(f"{i}行数据插入失败: {e}")
fail_index.append(i)
i += 1
continue
i += 1
success_count += 1
return f"成功插入了 {success_count} 条记录,失败的行: {fail_index}"
finally:
# 关闭连接
await db_client.close()
# Create the shared database client instance
db_client = AsyncMySQLClient(
    host='ngsk.tech',
    port=33306,
    user='root',
    password='ngsk0809cruise',
    db='data_governance'
)

async def get_db():
    # Return the shared database client instance
    return db_client

# Usage example
async def main():
    # Get the database client instance
    db_client = await get_db()
data = [{
"table_name_eng": "BPMS_RU_DEF_ATT_NODE_REL",
"table_name_zh": "自由流程运行时扩展属性节点关系",
"region": "规建域",
"description": "PROCESS_NAME-VARCHAR (100)- 流程的名称VERSION-NUMBER (5)- 流程的版本号NODE_ID-VARCHAR (100)- 节点的编号CREATE_TIME-DATETIME - 记录的创建时间,格式为 YYYY-MM-DD HH:mm:ss, ATT_NODE_REL_ID-VARCHAR (64)- 扩展属性与节点关系的编号ATT_ID-VARCHAR (64)- 扩展属性的编号ATT_VALUE-VARCHAR (2000)- 扩展属性的取值DEPLOY_ID-VARCHAR (64)- 部署的编号NODE_NAME-VARCHAR (100)- 节点的名称PROCESS_INS_ID-VARCHAR (40)- 关联的流程实例 ID, MODIFY_DATE-DATETIME - 记录的修改时间,格式为 YYYY-MM-DD HH:mm:ss, PROCESS_ID-VARCHAR (100)- 流程的编号CREATOR_ID-VARCHAR (40)- 创建人的标识 ID"
},{
"table_name_eng": "BPMS_RU_DEF_DEPLOYE",
"table_name_zh": "自由流程运行时部署基本信息表",
"region": "规建域",
"description": "CREATE_TIME-DATETIME - 记录的创建时间,格式为 YYYY-MM-DD HH:mm:ss, CREATOR_ID-VARCHAR (40)- 创建人的标识 ID, DEPLOY_DIR_CODE-VARCHAR (200)- 部署目录的编号DEPLOY_PERSON-VARCHAR (40)- 部署人的标识 ID, DEPLOY_PERSON_NAME-VARCHAR (100)- 部署人的姓名DEPLOY_TIME-DATETIME - 部署操作的时间,格式为 YYYY-MM-DD HH:mm:ss, LOCK_PERSON_NAME-VARCHAR (100)- 部署的姓名LOCKED-NUMBER (1)- 锁定状态标识MODIFY_DATE-DATETIME - 记录的修改时间,格式为 YYYY-MM-DD HH:mm:ss, PROCESS_DEF_ID-VARCHAR (40)- 流程定义文件的 ID, PROCESS_DEF_NAME-VARCHAR (100)- 流程定义文件的名称PROCESS_INS_ID-VARCHAR (40)- 关联的流程实例 ID, STATE-CHAR (1)- 部署的状态代码0 - 未处理1 - 已完成), UNINSTALL_TIME-DATETIME - 卸载操作的时间,格式为 YYYY-MM-DD HH:mm:ss, UPDATE_PERSON-VARCHAR (40)- 权限人的标识 ID, UPDATE_PERSON_NAME-VARCHAR (100)- 权限人的姓名VERSION-NUMBER (5)- 流程定义的版本号VERSION_DES-VARCHAR (200)- 流程定义版本的描述信息"
}]
await insert(data, db_client)
# Run directly for a quick test
if __name__ == "__main__":
    asyncio.run(main())

@ -0,0 +1,89 @@
from openai import AsyncOpenAI
client = AsyncOpenAI(
api_key="gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb",
base_url="http://192.168.5.20:4090/v1",
)
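# One client instance serves the chat-completion, embedding and prompt-based rerank calls below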
async def generation_rule(prompt):
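    """Send the chat messages to the local model and return its reply (used here to generate SQL rules)."""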
response = await client.chat.completions.create(
model="qwen3-30b-a3b-instruct-2507",
messages=prompt,
n = 1,
stream = False,
temperature=0.0,
max_tokens=600,
top_p = 1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
# stop = ["Q:"]
)
return response.choices[0].message.content
async def generation_vector(text):
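    # bge-m3 dense embeddings are 1024-dimensional, matching the vector(1024) columns in use_pgvector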
response = await client.embeddings.create(
model="bge-m3", # 替换为实际的向量模型名称
input=text,
encoding_format="float"
)
return response.data[0].embedding
async def rerank_documents(query, documents, top_n=None):
"""
调用 rerank 模型对文档进行重排序
:param query: 查询语句
:param documents: 文档列表
:param top_n: 返回前n个结果默认返回所有
:return: 重排序后的结果
"""
if top_n is None:
top_n = min(4, len(documents))
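    # Number the documents so the model can refer to them by index in its reply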
documents_str = "\n".join([f"{i+1}. {doc}" for i, doc in enumerate(documents)])
messages = [
{
"role": "system",
"content": "你是一个文档相关性排序助手。请根据查询语句与文档的相关性对文档进行排序,只返回文档序号的排序结果,如:[2, 1, 3]"
},
{
"role": "user",
"content": f"查询:{query}\n\n候选文档:\n{documents_str}\n\n请按相关性从高到低排序,只返回序号列表:"
}
]
response = await client.chat.completions.create(
model="qwen3-30b-a3b-instruct-2507", # 使用已知可用的模型
messages=messages,
temperature=0.0,
max_tokens=100
)
sorted_results = response.choices[0].message.content
return sorted_results
# Wrap top-level awaits in an async entry point
async def main():
    # prompt = [
    #     {"role": "system", "content": "你是一个助手"},
    #     {"role": "user", "content": "你是谁"}
    # ]
    # ans = await generation_rule(prompt)
    # ans = await generation_vector("hello world")
documents = [
"人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
"机器学习是人工智能的一个分支,主要研究计算机如何从数据中学习规律,并利用这些规律对未知数据进行预测。",
"深度学习是机器学习的一个子集,它模仿人脑的工作方式,通过多层神经网络来学习数据的特征。"
]
ans = await rerank_documents(query="机器学习", documents=documents, top_n=2)
print(ans)
return ans
# Run the async entry point
import asyncio

if __name__ == "__main__":
    asyncio.run(main())

@ -0,0 +1,143 @@
import psycopg2
from psycopg2.extras import execute_values
import numpy as np
# Database connection configuration
DB_CONFIG = {
"host": "192.168.5.30",
"database": "vectordb",
"user": "myuser",
"password": "mypassword",
"port": 5432
}
def connect_to_db():
"""建立数据库连接"""
try:
conn = psycopg2.connect(**DB_CONFIG)
print("成功连接到数据库")
return conn
except Exception as e:
print(f"连接数据库失败: {e}")
return None
def setup_vector_extension(conn):
"""设置pgvector扩展和向量表"""
try:
with conn.cursor() as cur:
# 启用pgvector扩展
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
# 创建包含向量的表
cur.execute("""
CREATE TABLE IF NOT EXISTS documents(
id SERIAL PRIMARY KEY,
table_name_eng varchar(50),
description TEXT,
embedding vector(1024) -- 支持最大1536维的向量
);
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS "case"(
id SERIAL PRIMARY KEY,
question TEXT,
answer TEXT,
embedding vector(1024) -- 支持最大1536维的向量
);
""")
conn.commit()
print("成功设置向量扩展和表结构")
except Exception as e:
print(f"设置数据库结构失败: {e}")
conn.rollback()
def insert_vectors(conn, data):
"""插入向量数据"""
try:
with conn.cursor() as cur:
# 数据格式: [(content, embedding), ...]
execute_values(
cur,
"INSERT INTO documents (table_name_eng, description, embedding) VALUES %s",
data
)
conn.commit()
print(f"成功插入 {len(data)} 条向量数据")
except Exception as e:
print(f"插入数据失败: {e}")
conn.rollback()
def search_similar_table(conn, query_vector, limit=5):
"""搜索相似向量"""
try:
with conn.cursor() as cur:
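            # pgvector's <=> operator computes cosine distance (smaller = more similar)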
cur.execute("""
SELECT table_name_eng, description, embedding <=> %s::vector AS distance
FROM documents
ORDER BY distance
LIMIT %s
""", (query_vector, limit))
results = cur.fetchall()
return results
    except Exception as e:
        print(f"Similarity search failed: {e}")
        return []
def search_similar_case(conn, query_vector, limit=5):
"""搜索相似向量"""
try:
with conn.cursor() as cur:
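            # pgvector's <-> operator computes Euclidean (L2) distance (smaller = more similar)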
cur.execute("""
SELECT question, answer, embedding <-> %s::vector AS distance
FROM "case"
ORDER BY distance
LIMIT %s
""", (query_vector, limit))
results = cur.fetchall()
return results
    except Exception as e:
        print(f"Similarity search failed: {e}")
        return []
def main():
    # Connect to the database
    conn = connect_to_db()
    if not conn:
        return
    # Set up the vector extension and tables
    setup_vector_extension(conn)
    # Example data
    # sample_data = [
    #     ("first document", [1.0, 2.0, 3.0]),
    #     ("second document", [4.0, 5.0, 6.0]),
    #     ("third document", [7.0, 8.0, 9.0]),
    #     ("fourth document", [2.0, 3.0, 4.0]),
    #     ("fifth document", [5.0, 6.0, 7.0]),
    # ]
    # Insert the data
    # insert_vectors(conn, sample_data)
    # Search for similar vectors (demo only: this 3-dim vector does not match the
    # 1024-dim embedding columns above, so it will fail against the real tables)
    query_vector = [1.0, 1.0, 1.0]
    similar_docs = search_similar_table(conn, query_vector)
    print("\nSimilar document search results:")
    for table_name_eng, description, distance in similar_docs:
        print(f"Table: {table_name_eng}, description: {description}, distance: {distance}")
    # Close the connection
    conn.close()
if __name__ == "__main__":
main()