from util import use_pgvector, use_opanai, use_mysql from util.log_util import logger from util.use_mysql import search_desc_by_table_names async def rag_generate_rule(query : str): if not query: return "请输入问题" # 询问大模型抽取用户问题中表名 prompt = [ {"role": "system", "content": f""" 需要从需要从用户问题中抽取表名,请返回一个列表,格式为:[表名1, 表名2, ...]。 例如:用户问题:"请查询表A和表B的交集",则返回:[A, B]。 请根据用户问题,抽取表名,请勿返回其他内容。 """}, {"role": "user", "content": query} ] tables_in_query = await use_opanai.generation_rule(prompt) tables_in_query = tables_in_query.strip('[]') tables_in_query = [item.strip() for item in tables_in_query.split(',')] # 将问题向量化 query_emb = await use_opanai.generation_vector("query") # 连接pgvector数据库 pgvector_conn = use_pgvector.connect_to_db() # 根据问题关联数据库表,得到table_name_list similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4) rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2) table_name_list = [similar_docs[int(index.strip())-1][0] for index in rerank_docs.strip('[]').split(',')] table_name_list.extend(tables_in_query) # print(f"【table_name_list】: {table_name_list}") logger.info(f"【table_name_list】: {table_name_list}") # 获得相关表的schema db_client = await use_mysql.get_db() schema = await search_desc_by_table_names(table_name_list, db_client) # print(f"【schema】: {schema}") logger.info(f"【schema】: {schema}") # 根据问题搜索相关案例 similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3) # print(f"【similar_case】: {similar_case}") logger.info(f"【similar_case】: {similar_case}") # 询问大模型生成SQL prompt = [ {"role": "system", "content": f""" 你是精通postgresql的专家,十分擅长利用用户的问题,和库表schema,生成合适的SQL。最终结果里只需要返回SQL,不要解释。 请根据用户问题,生成SQL,请勿返回其他内容。 schema: {schema} 参考案例: {similar_case} """}, {"role": "user", "content": query} ] # print(f"【prompt】: {prompt}") ans = await use_opanai.generation_rule(prompt) return ans async def main(): ans = await rag_generate_rule("“投资规模结构分解批次(备份)”表BAK_PC_IP_SCALE_STRUCT_RESOLVE中“业务板块类型”不合规的记录") print(f"[answer]: {ans}") # 使用asyncio运行异步函数 import asyncio if __name__ == "__main__": asyncio.run(main())