You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

56 lines
1.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from util import use_pgvector, use_opanai, use_mysql
from util.use_mysql import search_desc_by_table_names
def _pick_table_names(similar_docs, rerank_output, fallback_count=2):
    """Translate the reranker's answer into an ordered list of table names.

    ``rerank_output`` is parsed as a bracketed, comma-separated list of
    1-based positions into ``similar_docs`` (e.g. ``"[2, 1]"``). This is
    presumably the format ``use_opanai.rerank_documents`` returns; TODO confirm.
    ``similar_docs[i][0]`` is taken as the table name, matching how the rows
    from ``use_pgvector.search_similar_table`` are indexed by the caller.

    Tokens that are not integers or fall outside ``1..len(similar_docs)`` are
    skipped rather than raising or wrapping around via negative indexing.
    Each table appears at most once. If no usable position remains, the first
    ``fallback_count`` vector-search hits are returned instead.
    """
    table_names = []
    # str.strip("[]") removes any run of '[' / ']' characters from both ends.
    cleaned = str(rerank_output).strip().strip("[]")
    for token in cleaned.split(","):
        try:
            position = int(token)  # int() tolerates surrounding whitespace
        except ValueError:
            continue
        if not 1 <= position <= len(similar_docs):
            continue
        name = similar_docs[position - 1][0]
        if name not in table_names:
            table_names.append(name)
    if not table_names:
        table_names = [doc[0] for doc in similar_docs[:fallback_count]]
    return table_names


async def rag_generate_rule(query: str, *, table_limit: int = 4,
                            rerank_top_n: int = 2, case_limit: int = 3):
    """Generate SQL for a natural-language question via retrieval-augmented prompting.

    Pipeline:
      1. Embed the question.
      2. Find candidate tables in pgvector and rerank them with the LLM.
      3. Fetch the selected tables' descriptions through ``use_mysql``.
      4. Retrieve similar example cases from pgvector.
      5. Ask the LLM for the SQL.

    Args:
        query: The user's question.
        table_limit: Number of candidate tables fetched by vector similarity.
        rerank_top_n: Number of tables the reranker is asked to keep. Also
            used as the fallback count when its answer is unusable.
        case_limit: Number of similar example cases put into the prompt.

    Returns:
        ``"请输入问题"`` ("please enter a question") when ``query`` is empty
        or whitespace-only. Otherwise, whatever ``use_opanai.generation_rule``
        returns for the built prompt.
    """
    if not query or not query.strip():
        return "请输入问题"

    # NOTE(review): this connection is never closed here. Confirm whether
    # connect_to_db() hands out a fresh connection (a leak per call) or a
    # shared one before adding a close().
    pgvector_conn = use_pgvector.connect_to_db()

    # Embed the actual question text so that retrieval depends on it.
    query_emb = await use_opanai.generation_vector(query)

    # Select the tables relevant to the question: vector search, then rerank.
    similar_docs = use_pgvector.search_similar_table(
        pgvector_conn, query_emb, limit=table_limit)
    rerank_docs = await use_opanai.rerank_documents(
        query, similar_docs, top_n=rerank_top_n)
    table_name_list = _pick_table_names(
        similar_docs, rerank_docs, fallback_count=rerank_top_n)
    print(f"[table_name_list]: {table_name_list}")

    # Fetch the descriptions (schema) of the selected tables.
    db_client = await use_mysql.get_db()
    schema = await search_desc_by_table_names(table_name_list, db_client)
    print(f"[schema]: {schema}")

    # Retrieve similar worked examples to use as references in the prompt.
    similar_case = use_pgvector.search_similar_case(
        pgvector_conn, query_emb, limit=case_limit)
    print(f"[similar_case]: {similar_case}")

    # Ask the LLM for SQL. The system prompt (in Chinese) says: "You are an
    # expert in PostgreSQL, skilled at generating suitable SQL from the user's
    # question and the table schema. Return only the SQL, no explanation.
    # Generate SQL for the user's question and return nothing else."
    # NOTE(review): the prompt names PostgreSQL while the schema is read via
    # use_mysql. Confirm which SQL dialect the target database speaks.
    system_content = (
        "\n"
        "你是精通postgresql的专家十分擅长利用用户的问题和库表schema生成合适的SQL。最终结果里只需要返回SQL不要解释。\n"
        "请根据用户问题生成SQL请勿返回其他内容。\n"
        f"schema: {schema}\n"
        f"参考案例: {similar_case}\n"
    )
    prompt = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": query},
    ]
    print(f"[prompt]: {prompt}")
    ans = await use_opanai.generation_rule(prompt)
    return ans
async def main():
    """Run one sample question through ``rag_generate_rule`` and print the model's answer."""
    sample_question = "“投资规模结构分解批次备份”表BAK_PC_IP_SCALE_STRUCT_RESOLVE中“业务板块类型”不合规的记录"
    generated = await rag_generate_rule(sample_question)
    print(f"[answer]: {generated}")
# Script entry point: drive the async main() on a fresh asyncio event loop.
import asyncio

if __name__ == "__main__":
    asyncio.run(main())