|
|
from util import use_pgvector, use_opanai, use_mysql
|
|
|
from util.use_mysql import search_desc_by_table_names
|
|
|
|
|
|
|
|
|
async def rag_generate_rule(query: str):
    """Generate a rule (SQL) for a natural-language question via retrieval.

    Steps performed, in order:
      1. Embed the question with ``use_opanai.generation_vector``.
      2. Retrieve candidate tables from pgvector and rerank them.
      3. Look up the schema of the selected tables through ``use_mysql``.
      4. Retrieve similar cases from pgvector using the same embedding.
      5. Send a system prompt (schema + cases) and the question to
         ``use_opanai.generation_rule``.

    Args:
        query: The user's question.

    Returns:
        The literal hint ``"请输入问题"`` when ``query`` is empty; otherwise
        whatever ``use_opanai.generation_rule`` returns for the built prompt
        (presumably the generated SQL text -- confirm against that helper).

    Raises:
        ValueError / IndexError: if the rerank output cannot be parsed as a
            bracketed, comma-separated list of valid 1-based positions.
    """
    if not query:
        return "请输入问题"

    # Connect to the vector database.
    # NOTE(review): neither this connection nor ``db_client`` below is closed
    # in this function -- confirm whether the helpers manage their lifecycle.
    pgvector_conn = use_pgvector.connect_to_db()

    # Embed the question itself. Passing the variable (not the literal string
    # "query") is what makes both retrieval steps depend on the user's input.
    query_emb = await use_opanai.generation_vector(query)

    # Find the tables related to the question, then rerank the candidates.
    similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4)
    rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2)

    # ``rerank_docs`` is parsed as text such as "[2, 1]": 1-based positions
    # into ``similar_docs``. Element 0 of each hit is used as the table name
    # (presumably the table-name column -- verify against search_similar_table).
    table_name_list = [
        similar_docs[int(index.strip()) - 1][0]
        for index in rerank_docs.strip('[]').split(',')
    ]
    print(f"[table_name_list]: {table_name_list}")

    # Fetch the schema of the selected tables.
    db_client = await use_mysql.get_db()
    schema = await search_desc_by_table_names(table_name_list, db_client)
    print(f"[schema]: {schema}")

    # Retrieve similar cases for the question.
    similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3)
    print(f"[similar_case]: {similar_case}")

    # Ask the LLM to generate the SQL.
    prompt = [
        {"role": "system", "content": f"""
你是精通postgresql的专家,十分擅长利用用户的问题,和库表schema,生成合适的SQL。最终结果里只需要返回SQL,不要解释。

请根据用户问题,生成SQL,请勿返回其他内容。

schema: {schema}

参考案例: {similar_case}
"""},
        {"role": "user", "content": query},
    ]
    print(f"[prompt]: {prompt}")

    ans = await use_opanai.generation_rule(prompt)

    return ans
|
|
|
|
|
|
async def main():
    """Run one sample question through ``rag_generate_rule`` and print the answer."""
    sample_question = "“投资规模结构分解批次(备份)”表BAK_PC_IP_SCALE_STRUCT_RESOLVE中“业务板块类型”不合规的记录"
    answer = await rag_generate_rule(sample_question)
    print(f"[answer]: {answer}")
|
|
|
|
|
|
# Drive the async entry point with asyncio when this file is run as a script.
import asyncio

if __name__ == "__main__":
    asyncio.run(main())
|