from util import use_pgvector, use_opanai, use_mysql from util.use_mysql import search_desc_by_table_names async def rag_generate_rule(query : str): if not query: return "请输入问题" # 连接数据库 pgvector_conn = use_pgvector.connect_to_db() # 将问题向量化 query_emb = await use_opanai.generation_vector("query") # 根据问题关联数据库表,得到table_name_list similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4) rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2) table_name_list = [similar_docs[int(index.strip())-1][0] for index in rerank_docs.strip('[]').split(',')] print(f"[table_name_list]: {table_name_list}") # 获得相关表的schema db_client = await use_mysql.get_db() schema = await search_desc_by_table_names(table_name_list, db_client) print(f"[schema]: {schema}") # 根据问题搜索相关案例 similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3) print(f"[similar_case]: {similar_case}") # 询问大模型生成SQL prompt = [ {"role": "system", "content": f""" 你是精通postgresql的专家,十分擅长利用用户的问题,和库表schema,生成合适的SQL。最终结果里只需要返回SQL,不要解释。 请根据用户问题,生成SQL,请勿返回其他内容。 schema: {schema} 参考案例: {similar_case} """}, {"role": "user", "content": query} ] print(f"[prompt]: {prompt}") ans = await use_opanai.generation_rule(prompt) return ans async def main(): ans = await rag_generate_rule("“投资规模结构分解批次(备份)”表BAK_PC_IP_SCALE_STRUCT_RESOLVE中“业务板块类型”不合规的记录") print(f"[answer]: {ans}") # 使用asyncio运行异步函数 import asyncio if __name__ == "__main__": asyncio.run(main())