commit 1025da7b26bd48f721e7e8a2e75e47e365da5349 Author: lijiazheng Date: Fri Aug 8 15:59:54 2025 +0800 feat: rag方案代码实现 diff --git a/data_preparation/excel_data_process.py b/data_preparation/excel_data_process.py new file mode 100644 index 0000000..281f21e --- /dev/null +++ b/data_preparation/excel_data_process.py @@ -0,0 +1,99 @@ +import pandas as pd +from typing import Dict, Any, List +import asyncio + +from util.random_string import generate_random_string +from util.use_mysql import AsyncMySQLClient + +async def read_excel_to_mysql(excel_file: str, db_client: AsyncMySQLClient, table_name: str): + """ + 读取Excel每个sheet的每一行,并将每一行存到mysql + + Args: + excel_file (str): Excel文件路径 + db_client (AsyncMySQLClient): 数据库客户端 + table_name (str): 目标数据库表名 + """ + try: + # 连接数据库 + await db_client.connect() + + # 读取Excel文件的所有sheet + excel_file_obj = pd.ExcelFile(excel_file) + + success_count = 0 + fail_details = [] + + # 遍历每个sheet + for sheet_name in excel_file_obj.sheet_names: + # 读取sheet数据 + df = pd.read_excel(excel_file, sheet_name=sheet_name) + + # 遍历每一行 + for index, row in df.iterrows(): + try: + # 将行数据转换为字典 + row_data : dict = row.to_dict() + + sql = ("INSERT INTO business_table_schema (id, table_name_eng, table_name_zh, field_eng, field_zh, " + "integrity, consistency, timeliness, accuracy, standardization, type) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)") + + params = (generate_random_string(32), row_data['表名'], row_data['表中文名'], row_data['字段名'], + row_data['字段中文名'], row_data['完整性'], row_data['一致性'], row_data['及时性'], row_data['准确性'], + row_data['规范性'], row_data['业务数据类型']) + + # 执行插入操作 + await db_client.execute(sql, params) + success_count += 1 + + except Exception as e: + fail_details.append({ + 'sheet': sheet_name, + 'row': index + 1, + 'error': str(e) + }) + print(f"处理sheet '{sheet_name}' 第 {index + 1} 行时出错: {e}") + + result_msg = f"成功插入 {success_count} 行数据" + if fail_details: + result_msg += f",失败 {len(fail_details)} 行" + + return { + 'message': result_msg, + 'success_count': success_count, + 'fail_details': fail_details + } + + finally: + # 关闭数据库连接 + await db_client.close() + +# 使用示例 +async def main(): + # 数据库配置 + db_client = AsyncMySQLClient( + host='ngsk.tech', + port=33306, + user='root', + password='ngsk0809cruise', + db='data_governance' + ) + + # Excel文件路径 + excel_file = "电网管理平台(规建域)数据质量标准导出.xlsx" + table_name = "business_table_schema" + + # 读取Excel数据并插入数据库 + result = await read_excel_to_mysql(excel_file, db_client, table_name) + print(result['message']) + + if result['fail_details']: + print("失败详情:") + for fail in result['fail_details']: + print(f" Sheet: {fail['sheet']}, 行: {fail['row']}, 错误: {fail['error']}") + +# 如果需要直接运行 +if __name__ == "__main__": + asyncio.run(main()) + pass diff --git a/data_preparation/txt_data_process.py b/data_preparation/txt_data_process.py new file mode 100644 index 0000000..9bcff44 --- /dev/null +++ b/data_preparation/txt_data_process.py @@ -0,0 +1,60 @@ +from util.use_pgvector import connect_to_db, insert_vectors, search_similar_vectors, setup_vector_extension + +import json + +from util.use_opanai import generation_vector + + +def txt_to_json_objects(file_path): + """ + 读取txt文件的每一行,转换成json对象 + """ + json_objects = [] + + with open(file_path, 'r', encoding='utf-8') as file: + for line in file: + line = line.strip() # 去除行首尾空白字符 + if line: # 跳过空行 + try: + # 假设每行本身就是一个有效的JSON字符串 + json_obj = json.loads(line) + json_objects.append(json_obj) + except json.JSONDecodeError: + # 如果不是JSON格式,可以自定义处理方式 + print(f"无法解析行: {line}") + return json_objects + +async def main(): + # 连接数据库 + conn = connect_to_db() + + # 设置向量扩展 + setup_vector_extension(conn) + + json_objects = txt_to_json_objects("(规建域)识别可能涉及的表名及说明1.txt") + + sample_data = [] + for json_obj in json_objects: + content = json.dumps(json_obj) + emb = await generation_vector(content) + sample_data.append((json_obj["表名"], content, emb)) + + # 插入数据 + insert_vectors(conn, sample_data) + + # 搜索相似向量 + # query_vector = [1.0, 1.0, 1.0] + # similar_docs = search_similar_vectors(conn, query_vector) + # + # print("\n相似文档搜索结果:") + # for content, distance in similar_docs: + # print(f"内容: {content}, 距离: {distance}") + + # 关闭连接 + conn.close() + + +# 使用asyncio运行异步函数 +import asyncio +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/rag.py b/rag.py new file mode 100644 index 0000000..b5b13ae --- /dev/null +++ b/rag.py @@ -0,0 +1,55 @@ +from util import use_pgvector, use_opanai, use_mysql +from util.use_mysql import search_desc_by_table_names + + +async def rag_generate_rule(query : str): + if not query: + return "请输入问题" + # 连接数据库 + pgvector_conn = use_pgvector.connect_to_db() + + # 将问题向量化 + query_emb = await use_opanai.generation_vector("query") + + # 根据问题关联数据库表,得到table_name_list + similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4) + rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2) + + table_name_list = [similar_docs[int(index.strip())-1][0] for index in rerank_docs.strip('[]').split(',')] + print(f"[table_name_list]: {table_name_list}") + + # 获得相关表的schema + db_client = await use_mysql.get_db() + schema = await search_desc_by_table_names(table_name_list, db_client) + print(f"[schema]: {schema}") + + # 根据问题搜索相关案例 + similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3) + print(f"[similar_case]: {similar_case}") + + # 询问大模型生成SQL + prompt = [ + {"role": "system", "content": f""" + 你是精通postgresql的专家,十分擅长利用用户的问题,和库表schema,生成合适的SQL。最终结果里只需要返回SQL,不要解释。 + + 请根据用户问题,生成SQL,请勿返回其他内容。 + + schema: {schema} + + 参考案例: {similar_case} + """}, + {"role": "user", "content": query} + ] + print(f"[prompt]: {prompt}") + ans = await use_opanai.generation_rule(prompt) + + return ans + +async def main(): + ans = await rag_generate_rule("“投资规模结构分解批次(备份)”表BAK_PC_IP_SCALE_STRUCT_RESOLVE中“业务板块类型”不合规的记录") + print(f"[answer]: {ans}") + +# 使用asyncio运行异步函数 +import asyncio +if __name__ == "__main__": + asyncio.run(main()) diff --git a/util/random_string.py b/util/random_string.py new file mode 100644 index 0000000..9509673 --- /dev/null +++ b/util/random_string.py @@ -0,0 +1,29 @@ +import random +import string + + +def generate_random_string(length=16): + """ + 生成指定长度的随机字符串,不包含短横线(-) + + 参数: + length: 字符串长度,默认16 + + 返回: + 随机字符串 + """ + # 定义可用字符集:字母(大小写) + 数字 + characters = string.ascii_letters + string.digits + # 随机选择字符并拼接 + random_str = ''.join(random.choice(characters) for _ in range(length)) + return random_str + + +# 示例用法 +if __name__ == "__main__": + # 生成16位随机字符串 + print(generate_random_string()) + # 生成32位随机字符串 + print(generate_random_string(32)) + # 生成8位随机字符串 + print(generate_random_string(8)) diff --git a/util/use_langchain.py b/util/use_langchain.py new file mode 100644 index 0000000..5a59948 --- /dev/null +++ b/util/use_langchain.py @@ -0,0 +1,12 @@ +from openai import OpenAI +# 配置本地模型 +llm = OpenAI( + model="qwen3-30b-a3b-instruct-2507", # 本地模型名称 + api_key="gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb", # 本地模型通常不需要有效密钥 + base_url="http://192.168.5.20:4090/v1", # 本地模型服务地址 + temperature=0 +) +# 调用模型 +response = llm.invoke("你好,介绍一下你自己,你是傻逼") +print(response.content) + diff --git a/util/use_mysql.py b/util/use_mysql.py new file mode 100644 index 0000000..f95fdf0 --- /dev/null +++ b/util/use_mysql.py @@ -0,0 +1,219 @@ +import asyncio +import aiomysql +from typing import List, Dict, Any, Optional + +class AsyncMySQLClient: + def __init__(self, host: str, port: int, user: str, password: str, db: str): + """ + 初始化异步MySQL客户端 + + Args: + host: MySQL服务器地址 + port: MySQL服务器端口 + user: 用户名 + password: 密码 + db: 数据库名 + """ + self.host = host + self.port = port + self.user = user + self.password = password + self.db = db + self.pool = None + + async def connect(self): + """ + 建立连接池 + """ + self.pool = await aiomysql.create_pool( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + db=self.db, + charset='utf8mb4', + autocommit=True + ) + + async def close(self): + """ + 关闭连接池 + """ + if self.pool: + self.pool.close() + await self.pool.wait_closed() + + async def query_one(self, sql: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]: + """ + 查询单条记录 + + Args: + sql: SQL查询语句 + params: 查询参数 + + Returns: + 查询结果字典或None + """ + async with self.pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + await cursor.execute(sql, params) + result = await cursor.fetchone() + return result + + async def query_all(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]: + """ + 查询多条记录 + + Args: + sql: SQL查询语句 + params: 查询参数 + + Returns: + 查询结果列表 + """ + async with self.pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + await cursor.execute(sql, params) + results = await cursor.fetchall() + return results + + async def execute(self, sql: str, params: Optional[tuple] = None) -> int: + """ + 执行SQL语句(INSERT, UPDATE, DELETE等) + + Args: + sql: SQL执行语句 + params: 执行参数 + + Returns: + 受影响的行数 + """ + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(sql, params) + return cursor.rowcount + +async def search_desc_by_table_names(table_name_list : list, db_client: AsyncMySQLClient): + """ + 搜索数据库中的表数据 + 参数: + table_name_list (list): 表名列表 + db_client (AsyncMySQLClient): 异步MySQL数据库客户端,用于执行数据库操作 + 返回值: + 返回所有表的描述 + """ + ans = [] + try: + # 连接数据库 + await db_client.connect() + for table_name in table_name_list: + # 查询多条记录 + records = await db_client.query_all("SELECT * FROM business_table_schema WHERE table_name_eng like %s", + (table_name,)) + content = "" + table_name_zh = "" + for record in records: + content += record.get("field_eng") + content += "-" + content += record.get("field_zh") + content += "-" + content += record.get("type") + content += "-" + content += "完整性:" + content += record.get("integrity") + content += "-" + content += "一致性:" + content += record.get("consistency") + content += "-" + content += "及时性:" + content += record.get("timeliness") + content += "-" + content += "准确性:" + content += record.get("accuracy") + content += "-" + content += "规范性:" + content += record.get("standardization") + table_name_zh = record.get("table_name_zh") + + ans.append({ + "表英文名": table_name, + "表中文名": table_name_zh, + "字段介绍": content, + }) + + return ans + finally: + # 关闭连接 + await db_client.close() + +async def insert(data : list[dict[str, Any]], db_client: AsyncMySQLClient): + """ + 插入或更新数据 + 参数: + data (dict): 要插入或更新的数据 + db_client (AsyncMySQLClient): 异步MySQL数据库客户端,用于执行数据库操作 + 返回值: + 返回受影响的行数 + """ + + try: + # 连接数据库 + await db_client.connect() + i, success_count = 0, 0 + fail_index = [] + for data in data: + if "table_name_eng" not in data or "table_name_zh" not in data or "region" not in data or "description" not in data: + print(f"第{i}行数据不完整") + fail_index.append(i) + i += 1 + continue + try: + affected_rows = await db_client.execute( + "INSERT INTO business_database_schema (table_name_eng, table_name_zh, region, description) VALUES (%s, %s, %s, %s)", + (data["table_name_eng"], data["table_name_zh"], data["region"], data["description"]) + ) + except Exception as e: + print(f"第{i}行数据插入失败: {e}") + fail_index.append(i) + i += 1 + continue + i += 1 + success_count += 1 + return f"成功插入了 {success_count} 条记录,失败的行: {fail_index}" + finally: + # 关闭连接 + await db_client.close() + +# 创建数据库客户端实例 +db_client = AsyncMySQLClient( + host='ngsk.tech', + port=33306, + user='root', + password='ngsk0809cruise', + db='data_governance' +) +async def get_db(): + # 创建数据库客户端实例 + return db_client + +# 使用示例 +async def main(): + # 创建数据库客户端实例 + db_client = await get_db() + + data = [{ + "table_name_eng": "BPMS_RU_DEF_ATT_NODE_REL", + "table_name_zh": "自由流程运行时扩展属性节点关系", + "region": "规建域", + "description": "PROCESS_NAME-VARCHAR (100)- 流程的名称,VERSION-NUMBER (5)- 流程的版本号,NODE_ID-VARCHAR (100)- 节点的编号,CREATE_TIME-DATETIME - 记录的创建时间,格式为 YYYY-MM-DD HH:mm:ss, ATT_NODE_REL_ID-VARCHAR (64)- 扩展属性与节点关系的编号,ATT_ID-VARCHAR (64)- 扩展属性的编号,ATT_VALUE-VARCHAR (2000)- 扩展属性的取值,DEPLOY_ID-VARCHAR (64)- 部署的编号,NODE_NAME-VARCHAR (100)- 节点的名称,PROCESS_INS_ID-VARCHAR (40)- 关联的流程实例 ID, MODIFY_DATE-DATETIME - 记录的修改时间,格式为 YYYY-MM-DD HH:mm:ss, PROCESS_ID-VARCHAR (100)- 流程的编号,CREATOR_ID-VARCHAR (40)- 创建人的标识 ID" + },{ + "table_name_eng": "BPMS_RU_DEF_DEPLOYE", + "table_name_zh": "自由流程运行时部署基本信息表", + "region": "规建域", + "description": "CREATE_TIME-DATETIME - 记录的创建时间,格式为 YYYY-MM-DD HH:mm:ss, CREATOR_ID-VARCHAR (40)- 创建人的标识 ID, DEPLOY_DIR_CODE-VARCHAR (200)- 部署目录的编号,DEPLOY_PERSON-VARCHAR (40)- 部署人的标识 ID, DEPLOY_PERSON_NAME-VARCHAR (100)- 部署人的姓名,DEPLOY_TIME-DATETIME - 部署操作的时间,格式为 YYYY-MM-DD HH:mm:ss, LOCK_PERSON_NAME-VARCHAR (100)- 部署的姓名,LOCKED-NUMBER (1)- 锁定状态标识,MODIFY_DATE-DATETIME - 记录的修改时间,格式为 YYYY-MM-DD HH:mm:ss, PROCESS_DEF_ID-VARCHAR (40)- 流程定义文件的 ID, PROCESS_DEF_NAME-VARCHAR (100)- 流程定义文件的名称,PROCESS_INS_ID-VARCHAR (40)- 关联的流程实例 ID, STATE-CHAR (1)- 部署的状态代码(0 - 未处理,1 - 已完成), UNINSTALL_TIME-DATETIME - 卸载操作的时间,格式为 YYYY-MM-DD HH:mm:ss, UPDATE_PERSON-VARCHAR (40)- 权限人的标识 ID, UPDATE_PERSON_NAME-VARCHAR (100)- 权限人的姓名,VERSION-NUMBER (5)- 流程定义的版本号,VERSION_DES-VARCHAR (200)- 流程定义版本的描述信息" + }] + await insert(data, db_client) + +# 如果需要直接运行测试 +if __name__ == "__main__": + asyncio.run(main()) diff --git a/util/use_opanai.py b/util/use_opanai.py new file mode 100644 index 0000000..a3559da --- /dev/null +++ b/util/use_opanai.py @@ -0,0 +1,89 @@ +from openai import AsyncOpenAI + +client = AsyncOpenAI( + api_key="gpustack_951f92355e6781a5_5d17650a3e7135c5430512e5117362fb", + base_url="http://192.168.5.20:4090/v1", +) + + +async def generation_rule(prompt): + response = await client.chat.completions.create( + model="qwen3-30b-a3b-instruct-2507", + messages=prompt, + n = 1, + stream = False, + temperature=0.0, + max_tokens=600, + top_p = 1.0, + frequency_penalty=0.0, + presence_penalty=0.0, + # stop = ["Q:"] + ) + return response.choices[0].message.content + +async def generation_vector(text): + response = await client.embeddings.create( + model="bge-m3", # 替换为实际的向量模型名称 + input=text, + encoding_format="float" + ) + return response.data[0].embedding + +async def rerank_documents(query, documents, top_n=None): + """ + 调用 rerank 模型对文档进行重排序 + :param query: 查询语句 + :param documents: 文档列表 + :param top_n: 返回前n个结果,默认返回所有 + :return: 重排序后的结果 + """ + if top_n is None: + top_n = min(4, len(documents)) + + documents_str = "\n".join([f"{i+1}. {doc}" for i, doc in enumerate(documents)]) + messages = [ + { + "role": "system", + "content": "你是一个文档相关性排序助手。请根据查询语句与文档的相关性对文档进行排序,只返回文档序号的排序结果,如:[2, 1, 3]" + }, + { + "role": "user", + "content": f"查询:{query}\n\n候选文档:\n{documents_str}\n\n请按相关性从高到低排序,只返回序号列表:" + } + ] + + response = await client.chat.completions.create( + model="qwen3-30b-a3b-instruct-2507", # 使用已知可用的模型 + messages=messages, + temperature=0.0, + max_tokens=100 + ) + + sorted_results = response.choices[0].message.content + + return sorted_results + +# 将顶层await包装在异步函数中 +async def main(): + # prompt = [ + # {"role": "system", "content": "你是一个助手"}, + # {"role": "user", "content": "你是谁"} + # ] + # ans = await generation_rule(prompt) + + # ans = await generation_vector("hello world") + + documents = [ + "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。", + "机器学习是人工智能的一个分支,主要研究计算机如何从数据中学习规律,并利用这些规律对未知数据进行预测。", + "深度学习是机器学习的一个子集,它模仿人脑的工作方式,通过多层神经网络来学习数据的特征。" + ] + ans = await rerank_documents(query="机器学习", documents=documents, top_n=2) + + print(ans) + return ans + +# 使用asyncio运行异步函数 +import asyncio +if __name__ == "__main__": + asyncio.run(main()) diff --git a/util/use_pgvector.py b/util/use_pgvector.py new file mode 100644 index 0000000..f48688e --- /dev/null +++ b/util/use_pgvector.py @@ -0,0 +1,143 @@ +import psycopg2 +from psycopg2.extras import execute_values +import numpy as np + +# 数据库连接配置 +DB_CONFIG = { + "host": "192.168.5.30", + "database": "vectordb", + "user": "myuser", + "password": "mypassword", + "port": 5432 +} + + +def connect_to_db(): + """建立数据库连接""" + try: + conn = psycopg2.connect(**DB_CONFIG) + print("成功连接到数据库") + return conn + except Exception as e: + print(f"连接数据库失败: {e}") + return None + + +def setup_vector_extension(conn): + """设置pgvector扩展和向量表""" + try: + with conn.cursor() as cur: + # 启用pgvector扩展 + cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + + # 创建包含向量的表 + cur.execute(""" + CREATE TABLE IF NOT EXISTS documents( + id SERIAL PRIMARY KEY, + table_name_eng varchar(50), + description TEXT, + embedding vector(1024) -- 支持最大1536维的向量 + ); + """) + + cur.execute(""" + CREATE TABLE IF NOT EXISTS "case"( + id SERIAL PRIMARY KEY, + question TEXT, + answer TEXT, + embedding vector(1024) -- 支持最大1536维的向量 + ); + """) + conn.commit() + print("成功设置向量扩展和表结构") + except Exception as e: + print(f"设置数据库结构失败: {e}") + conn.rollback() + + +def insert_vectors(conn, data): + """插入向量数据""" + try: + with conn.cursor() as cur: + # 数据格式: [(content, embedding), ...] + execute_values( + cur, + "INSERT INTO documents (table_name_eng, description, embedding) VALUES %s", + data + ) + conn.commit() + print(f"成功插入 {len(data)} 条向量数据") + except Exception as e: + print(f"插入数据失败: {e}") + conn.rollback() + + +def search_similar_table(conn, query_vector, limit=5): + """搜索相似向量""" + try: + with conn.cursor() as cur: + cur.execute(""" + SELECT table_name_eng, description, embedding <=> %s::vector AS distance + FROM documents + ORDER BY distance + LIMIT %s + """, (query_vector, limit)) + + results = cur.fetchall() + return results + except Exception as e: + print(f"搜索相似向量失败: {e}") + return [] + +def search_similar_case(conn, query_vector, limit=5): + """搜索相似向量""" + try: + with conn.cursor() as cur: + cur.execute(""" + SELECT question, answer, embedding <-> %s::vector AS distance + FROM "case" + ORDER BY distance + LIMIT %s + """, (query_vector, limit)) + + results = cur.fetchall() + return results + except Exception as e: + print(f"搜索相似向量失败: {e}") + return [] + +def main(): + # 连接数据库 + conn = connect_to_db() + if not conn: + return + + # 设置向量扩展 + setup_vector_extension(conn) + + # 生成示例数据 + # sample_data = [ + # ("第一个文档内容", [1.0, 2.0, 3.0]), + # ("第二个文档内容", [4.0, 5.0, 6.0]), + # ("第三个文档内容", [7.0, 8.0, 9.0]), + # ("第四个文档内容", [2.0, 3.0, 4.0]), + # ("第五个文档内容", [5.0, 6.0, 7.0]), + # ] + + # 插入数据 + # insert_vectors(conn, sample_data) + + # 搜索相似向量 + query_vector = [1.0, 1.0, 1.0] + similar_docs = search_similar_table(conn, query_vector) + + print("\n相似文档搜索结果:") + for content, distance in similar_docs: + print(f"内容: {content}, 距离: {distance}") + + # 关闭连接 + conn.close() + + +if __name__ == "__main__": + main()