You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gcgj-dify-1.7.0/api/extensions/utils/vanna_text2sql.py

426 lines
16 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import json
from vanna.ollama import Ollama
from vanna.qianwen import QianWenAI_Chat
from vanna.deepseek import DeepSeekChat
from extensions.utils.rewrite_ask import ask
from dotenv import load_dotenv
import plotly.io as pio
from vanna.milvus import Milvus_VectorStore
from pymilvus import MilvusClient,model
from collections import defaultdict
load_dotenv()
# Render Plotly figures in a web browser.
# NOTE(review): this module lives under a server's api/ tree; a 'browser' renderer
# only makes sense for interactive use — confirm it is intended here.
pio.renderers.default = 'browser'
# Default Hugging Face downloads to a mirror, but keep any HF_ENDPOINT that was
# already provided via the process environment or the .env file loaded above.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
from typing import List
import ollama
import numpy as np
from pymilvus.model.base import BaseEmbeddingFunction
# Custom embedding model adapted for the Milvus vector database.
class CustomEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function that asks an Ollama server for embeddings.

    ``config`` (optional dict) may contain:
        host: Ollama server URL (default ``'http://wsd.wisdomidata.com:19042'``).
        embed_model: Ollama embedding model name (default ``'bge-m3'``).
        keep_alive: forwarded unchanged to ``ollama.Client.embeddings``.
        options: dict of model options forwarded unchanged to ``ollama.Client.embeddings``.
    """

    def __init__(self, config=None):
        # Treat a missing config as empty so the documented default does not crash.
        config = config or {}
        model_host = config.get("host", "http://wsd.wisdomidata.com:19042")
        self.embed_model = config.get("embed_model", "bge-m3")
        self.embedding_model = ollama.Client(model_host)
        self.keep_alive = config.get("keep_alive", None)
        self.ollama_options = config.get("options", {})
        # Only recorded here; it reaches Ollama solely as part of ``ollama_options``.
        self.num_ctx = self.ollama_options.get("num_ctx", 2048)

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Embed ``texts`` and return one numpy vector per text (previously returned None)."""
        return self.encode_documents(texts)

    def _encode(self, texts: list[str]) -> list[list[float]]:
        """Request one embedding per text from the Ollama server, in input order."""
        return [
            self.embedding_model.embeddings(
                model=self.embed_model,
                prompt=text,
                options=self.ollama_options,
                keep_alive=self.keep_alive,
            )["embedding"]
            for text in texts
        ]

    def _encode_arrays(self, texts: List[str]) -> List[np.ndarray]:
        """Embed ``texts`` and convert each embedding to an ``np.ndarray``."""
        return [np.array(embedding) for embedding in self._encode(texts)]

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed documents to be stored in the vector database."""
        return self._encode_arrays(documents)

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Embed search queries (same model and settings as documents)."""
        return self._encode_arrays(queries)
class VannaServer:
    """Wrap a Vanna text-to-SQL instance and manage its Milvus training data.

    ``config`` keys read when building the instance:

    * required: ``llm_type``, ``model``, ``api_key``, ``milvus_uri``, ``sql_type``;
    * optional: ``ollama_host``; ``host``/``dbname``/``user``/``password``/``port``
      (falling back to the ``DB_HOST``/``DB_NAME``/``DB_USER``/``DB_PASSWORD``/``DB_PORT``
      environment variables); ``milvus_database``; ``embedding_host``; ``embedding_model``.

    Milvus collections touched here: ``vannaddl`` (DDL), ``vannadoc`` (documentation)
    and ``vannasql`` (question/SQL pairs). Bulk reads are capped at 10000 rows.
    """

    # Prefix identifying the schema docs that ``refresh_schema_train`` regenerates via
    # ``schema_train``; every other ``vannadoc`` entry is user/dictionary documentation.
    _SCHEMA_DOC_PREFIX = "The following columns are in the "

    # PostgreSQL query producing one pseudo "CREATE TABLE" DDL string per table in schema
    # 'public': commented columns, column comments, foreign keys, and the table comment.
    # Only columns that have a comment appear (INNER JOIN on pg_description).
    _TABLE_DDL_SQL = """
        SELECT
            'CREATE TABLE '
            || C.TABLE_NAME
            || ' ('
            || C.COLUMN_NAMES
            || ');'
            || C.COMMENT_COLUMNS
            || CASE WHEN FK.FOREIGN_KEY_COLUMNS IS NOT NULL THEN FK.FOREIGN_KEY_COLUMNS ELSE '' END
            || CASE WHEN FK.FOREIGN_KEY_DESC IS NOT NULL THEN FK.FOREIGN_KEY_DESC ELSE '' END
            || 'COMMENT ON TABLE '
            || C.TABLE_NAME
            || ' IS '''
            || COALESCE(G.DESCRIPTION, '')
            || ''';'
            AS DDL,
            C.TABLE_NAME
        FROM (
            SELECT
                COL.TABLE_NAME,
                COL.TABLE_SCHEMA,
                STRING_AGG(
                    COL.COLUMN_NAME
                    || ' '
                    || COL.DATA_TYPE
                    || COALESCE('(' || COL.CHARACTER_MAXIMUM_LENGTH || ')', '')
                    || COALESCE(' DEFAULT ' || COL.COLUMN_DEFAULT, '')
                    || CASE
                        WHEN COL.IS_NULLABLE = 'NO' THEN ' NOT NULL'
                        ELSE ''
                    END,
                    ','
                    ORDER BY COL.ORDINAL_POSITION
                ) AS COLUMN_NAMES,
                STRING_AGG(
                    'COMMENT ON COLUMN '
                    || COL.TABLE_NAME
                    || '.'
                    || COL.COLUMN_NAME
                    || ' IS '''
                    || PGD.DESCRIPTION
                    || ''';',
                    ''
                    ORDER BY COL.ORDINAL_POSITION
                ) AS COMMENT_COLUMNS
            FROM
                PG_CATALOG.PG_STATIO_ALL_TABLES AS ST
            INNER JOIN
                PG_CATALOG.PG_DESCRIPTION AS PGD
                ON PGD.OBJOID = ST.RELID
            INNER JOIN
                INFORMATION_SCHEMA.COLUMNS AS COL
                ON (
                    COL.TABLE_SCHEMA = ST.SCHEMANAME
                    AND COL.TABLE_NAME = ST.RELNAME
                    AND COL.ORDINAL_POSITION = PGD.OBJSUBID
                )
            WHERE
                COL.TABLE_SCHEMA = 'public'
            GROUP BY
                COL.TABLE_SCHEMA,
                COL.TABLE_NAME
        ) C
        LEFT JOIN (
            SELECT
                N.NSPNAME AS SCHEMA_NAME,
                C.RELNAME AS TABLE_NAME,
                D.DESCRIPTION
            FROM
                PG_CATALOG.PG_DESCRIPTION D
            JOIN
                PG_CATALOG.PG_CLASS C
                ON C.OID = D.OBJOID
            JOIN
                PG_CATALOG.PG_NAMESPACE N
                ON N.OID = C.RELNAMESPACE
            WHERE
                C.RELKIND = 'r'
                AND D.OBJSUBID = 0
        ) G
            ON G.SCHEMA_NAME = C.TABLE_SCHEMA
            AND G.TABLE_NAME = C.TABLE_NAME
        LEFT JOIN (
            SELECT rel_src.relname AS source_table,
                STRING_AGG(
                    'ALTER TABLE '
                    || rel_src.relname
                    || ' ADD CONSTRAINT '
                    || con.conname
                    || ' FOREIGN KEY ('
                    || att_src.attname
                    || ') REFERENCES '
                    || rel_tgt.relname
                    || '('
                    || att_tgt.attname
                    || ');'
                    ,
                    ''
                ) AS FOREIGN_KEY_COLUMNS,
                STRING_AGG(
                    'COMMENT ON CONSTRAINT '
                    || con.conname
                    || ' ON '
                    || rel_src.relname
                    || ' IS '''
                    || d.description
                    || ''';',
                    ''
                ) AS FOREIGN_KEY_DESC
            FROM
                pg_constraint con
                JOIN pg_class rel_src ON rel_src.oid = con.conrelid
                JOIN pg_class rel_tgt ON rel_tgt.oid = con.confrelid
                JOIN pg_attribute att_src ON att_src.attrelid = rel_src.oid AND att_src.attnum = ANY(con.conkey)
                JOIN pg_attribute att_tgt ON att_tgt.attrelid = rel_tgt.oid AND att_tgt.attnum = ANY(con.confkey)
                LEFT JOIN pg_description d ON d.objoid = con.oid
            WHERE
                con.contype = 'f'
            GROUP BY
                rel_src.relname
        ) FK ON FK.source_table = C.TABLE_NAME
        WHERE C.TABLE_NAME NOT IN ('flyway_table_dict','flyway_schema_history')
    """

    def __init__(self, config):
        self.config = config
        self.vn = self._initialize_vn()

    def _initialize_vn(self):
        """Build the Vanna instance described by ``self.config`` and connect it to the database.

        Raises:
            KeyError: if a required config key is missing.
        """
        cfg = self.config
        llm_type = cfg["llm_type"]
        model_name = cfg["model"]
        api_key = cfg["api_key"]
        ollama_host = cfg.get("ollama_host")
        milvus_uri = cfg["milvus_uri"]
        sql_type = cfg["sql_type"]
        # Database settings: explicit config wins, otherwise environment variables.
        host = cfg.get("host", os.getenv("DB_HOST", "localhost"))
        dbname = cfg.get("dbname", os.getenv("DB_NAME", "dify_data"))
        user = cfg.get("user", os.getenv("DB_USER", "root"))
        password = cfg.get("password", os.getenv("DB_PASSWORD", "mysql"))
        # Kept conditional so an invalid DB_PORT is only parsed when it is actually used.
        port = cfg["port"] if "port" in cfg else int(os.getenv("DB_PORT", 3306))
        milvus_client = MilvusClient(uri=milvus_uri, db_name=cfg.get("milvus_database", "test"))
        embedding_function = CustomEmbeddingFunction({
            "host": cfg.get("embedding_host", "http://wsd.wisdomidata.com:19042"),
            "embed_model": cfg.get("embedding_model", "bge-m3"),  # BAAI/bge-m3
        })
        vn_config = {
            "model": model_name,
            "milvus_client": milvus_client,
            "n_results": 12,
            "embedding_function": embedding_function,
        }
        # Local Ollama models need the server address; hosted models need an API key.
        if llm_type == "ollama":
            vn_config["ollama_host"] = ollama_host
        else:
            vn_config["api_key"] = api_key
        # Any llm_type other than these (including "ollama") uses the Ollama chat class.
        chat_llm = {"tongyi": QianWenAI_Chat, "deepseek": DeepSeekChat}.get(llm_type, Ollama)
        vn = make_vanna_class(ChatClass=chat_llm)(vn_config)
        if sql_type == "postgres":
            vn.connect_to_postgres(host=host, dbname=dbname, user=user, password=password, port=port)
        elif sql_type == "mysql":
            vn.connect_to_mysql(host=host, dbname=dbname, user=user, password=password, port=port)
        # NOTE(review): any other sql_type leaves the instance without a database connection.
        return vn

    def _matching_ids(self, collection_name, field, predicate):
        """Return ids of rows in ``collection_name`` whose ``field`` satisfies ``predicate``.

        Only the (at most 10000) rows returned by a single Milvus query are examined.
        """
        rows = self.vn.milvus_client.query(
            collection_name=collection_name,
            output_fields=["*"],
            limit=10000,
        )
        return [row["id"] for row in rows or [] if predicate(row[field])]

    def _delete_ids(self, collection_name, ids):
        """Delete ``ids`` from ``collection_name``; does nothing for an empty list."""
        if ids:
            self.vn.milvus_client.delete(collection_name=collection_name, ids=ids)

    def schema_train(self):
        """Train on a plan built from the INFORMATION_SCHEMA.COLUMNS rows of schema 'public'."""
        # The information-schema query may need tweaking for other databases.
        df_information_schema = self.vn.run_sql(
            "SELECT * FROM INFORMATION_SCHEMA.COLUMNS where table_schema = 'public'"
        )
        # Break the schema into chunks the LLM can reference, then train on them.
        plan = self.vn.get_training_plan_generic(df_information_schema)
        self.vn.train(plan=plan)

    def refresh_create_table_ddl_train(self):
        """Regenerate the CREATE TABLE DDL training data (PostgreSQL only).

        Runs the DDL query first, then removes every stored ``vannaddl`` entry that
        starts with ``CREATE TABLE `` and trains on the freshly generated statements.
        """
        ddl_records = self.vn.run_sql(self._TABLE_DDL_SQL).to_dict(orient="records")
        stale_ids = self._matching_ids(
            "vannaddl", "ddl", lambda ddl: ddl.startswith("CREATE TABLE ")
        )
        self._delete_ids("vannaddl", stale_ids)
        for record in ddl_records:
            ddl = record["ddl"]
            # Never feed a NULL/NaN DDL cell into training.
            if isinstance(ddl, str) and ddl:
                self.vn.train(ddl=ddl)
        self.vn.milvus_client.refresh_load(collection_name="vannaddl")

    def refresh_schema_train(self):
        """Replace the schema documentation in ``vannadoc`` by re-running ``schema_train``."""
        stale_ids = self._matching_ids(
            "vannadoc", "doc", lambda doc: doc.startswith(self._SCHEMA_DOC_PREFIX)
        )
        self._delete_ids("vannadoc", stale_ids)
        self.schema_train()
        self.vn.milvus_client.refresh_load(collection_name="vannadoc")

    def update_schema_train_list(self, docs: list[str]):
        """Replace all non-schema documentation in ``vannadoc``.

        Every ``vannadoc`` entry not starting with the schema-doc prefix is deleted, then
        ``docs`` followed by the dictionary docs from ``get_dict_docs`` are trained.
        ``docs`` itself is not modified.
        """
        # Fetch the dictionary docs before deleting so a DB error leaves Milvus untouched.
        dict_docs = self.get_dict_docs()
        stale_ids = self._matching_ids(
            "vannadoc", "doc", lambda doc: not doc.startswith(self._SCHEMA_DOC_PREFIX)
        )
        self._delete_ids("vannadoc", stale_ids)
        for doc in [*docs, *dict_docs]:
            self.vn.train(documentation=doc)
        self.vn.milvus_client.refresh_load(collection_name="vannadoc")

    def get_dict_docs(self) -> list[str]:
        """Build one documentation string per table listed in ``flyway_table_dict``.

        Each string has the form ``<table_remark>表:<table_name>,`` followed by
        ``字段:<column_remark>(<column_name>)的值:<dict_values>`` entries joined by ``;``.
        Tables appear in the order they are first seen in the query result.
        """
        sql = "select id,table_name,column_name,column_remark,table_remark,dict_values from flyway_table_dict"
        records = self.vn.run_sql(sql).to_dict(orient="records")
        grouped = defaultdict(list)
        for record in records:
            grouped[record["table_name"]].append(record)
        dict_docs = []
        for table_name, columns in grouped.items():
            # Single quotes inside the f-strings keep this valid before Python 3.12 (PEP 701).
            dict_values = ";".join(
                f"字段:{col['column_remark']}({col['column_name']})的值:{col['dict_values']}"
                for col in columns
            )
            dict_docs.append(f"{columns[0]['table_remark']}表:{table_name},{dict_values}")
        return dict_docs

    def vn_train(self, question="", sql="", documentation="", ddl=""):
        """Add training data; every non-empty argument is trained.

        A question is only used together with ``sql``; ``sql`` alone is trained by itself.
        """
        if question and sql:
            self.vn.train(question=question, sql=sql)
        elif sql:
            self.vn.train(sql=sql)
        if documentation:
            self.vn.train(documentation=documentation)
        if ddl:
            self.vn.train(ddl=ddl)

    def get_training_data(self):
        """Return whatever ``self.vn.get_training_data()`` returns."""
        return self.vn.get_training_data()

    def ask(self, question, visualize=True, auto_train=True, *args, **kwargs):
        """Delegate to the module-level ``ask`` helper; returns its ``(sql, df, fig)``."""
        sql, df, fig = ask(self.vn, question, visualize=visualize, auto_train=auto_train, *args, **kwargs)
        return sql, df, fig

    def generate_sql(self, question):
        """Generate SQL for ``question`` without executing it."""
        return self.vn.generate_sql(question=question)

    def run_sql(self, sql):
        """Execute ``sql`` on the connected database."""
        return self.vn.run_sql(sql=sql)

    def training_data_export(self):
        """Return up to 10000 stored question/SQL pairs as ``{"question", "sql"}`` dicts."""
        training_data = self.vn.milvus_client.query(
            collection_name="vannasql",
            output_fields=["*"],
            limit=10000,
        )
        if training_data is None:
            return []
        return [{"question": t["text"], "sql": t["sql"]} for t in training_data]

    def training_data_import(self, data_list):
        """Import question/SQL pairs, replacing stored pairs with the same question.

        Returns:
            True if the import was rejected because some item has a missing/empty
            ``question`` or ``sql`` (nothing is written); False after a successful import.
        """
        if any(not item.get("question") or not item.get("sql") for item in data_list):
            return True
        questions = {item["question"] for item in data_list}
        stale_ids = self._matching_ids("vannasql", "text", lambda text: text in questions)
        self._delete_ids("vannasql", stale_ids)
        for item in data_list:
            self.vn.train(question=item["question"], sql=item["sql"])
        self.vn.milvus_client.refresh_load(collection_name="vannasql")
        return False
def make_vanna_class(ChatClass=Ollama):
    """Return a Vanna class that combines the Milvus vector store with ``ChatClass``.

    Args:
        ChatClass: the Vanna chat/LLM class to mix in (defaults to ``Ollama``).

    Returns:
        A new class named ``MyVanna``. Its constructor takes one ``config`` dict and
        hands it to ``Milvus_VectorStore.__init__`` and then to ``ChatClass.__init__``.
    """

    class MyVanna(Milvus_VectorStore, ChatClass):
        def __init__(self, config=None):
            # Initialise both bases explicitly with the same config, vector store first.
            for base in (Milvus_VectorStore, ChatClass):
                base.__init__(self, config=config)

        def is_sql_valid(self, sql: str) -> bool:
            # NOTE(review): placeholder override — it reports every SQL string as invalid.
            # Confirm this is intended; removing it would presumably restore the base-class check.
            return False

        def generate_query_explanation(self, sql: str):
            """Ask the LLM to explain ``sql`` and return the submitted prompt's response."""
            system_msg = self.system_message("You are a helpful assistant that will explain a SQL query")
            user_msg = self.user_message("Explain this SQL query: " + sql)
            return self.submit_prompt(prompt=[system_msg, user_msg])

    return MyVanna
# Usage example
if __name__ == '__main__':
    # VannaServer reads llm_type, model, api_key, milvus_uri and sql_type with config[...],
    # so all of them must be present; values come from the environment so nothing
    # secret is hard-coded. Database settings fall back to the DB_* environment variables.
    config = {
        "supplier": "GITEE",
        "llm_type": os.getenv("VANNA_LLM_TYPE", "ollama"),
        "model": os.getenv("VANNA_MODEL", ""),
        "api_key": os.getenv("VANNA_API_KEY", ""),
        "ollama_host": os.getenv("OLLAMA_HOST", "http://localhost:11434"),
        "milvus_uri": os.getenv("MILVUS_URI", "http://localhost:19530"),
        "sql_type": os.getenv("VANNA_SQL_TYPE", "postgres"),
    }
    server = VannaServer(config)
    # Call server.schema_train() first to populate the schema training data.
    server.ask("汇总每个类别的销售量和销售额, 并按照销售量进行降序排列")