import os
import json
from vanna.ollama import Ollama
from vanna.qianwen import QianWenAI_Chat
from vanna.deepseek import DeepSeekChat
from extensions.utils.rewrite_ask import ask
from dotenv import load_dotenv
import plotly.io as pio
from vanna.milvus import Milvus_VectorStore
from pymilvus import MilvusClient, model
from collections import defaultdict

load_dotenv()
# Use the browser as the plotly display backend.
pio.renderers.default = 'browser'
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# NOTE: these imports stay below the environment setup, in the same order as
# the original file, so module-level side effects run in the same sequence.
from typing import List
import ollama
import numpy as np
from pymilvus.model.base import BaseEmbeddingFunction

# Defaults used when the caller does not configure an embedding service.
_DEFAULT_EMBEDDING_HOST = 'http://wsd.wisdomidata.com:19042'
_DEFAULT_EMBEDDING_MODEL = 'bge-m3'

# Maximum number of rows fetched from a Milvus collection in one query.
_MILVUS_QUERY_LIMIT = 10000

# Prefix of the documentation entries that `schema_train` stores; used to tell
# them apart from the other documents kept in the "vannadoc" collection.
_SCHEMA_DOC_PREFIX = "The following columns are in the "


class CustomEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function backed by an Ollama server, for use with Milvus.

    Recognised ``config`` keys (all optional):

    * ``host``: Ollama server URL.
    * ``embed_model``: name of the embedding model.
    * ``keep_alive``: forwarded to every embeddings request.
    * ``options``: Ollama options dict forwarded to every embeddings request.
    """

    def __init__(self, config=None):
        # The signature advertises ``config=None``; treat it as "no overrides"
        # instead of failing on the membership tests below.
        config = config or {}
        model_host = config.get("host", _DEFAULT_EMBEDDING_HOST)
        self.embed_model = config.get("embed_model", _DEFAULT_EMBEDDING_MODEL)
        self.embedding_model = ollama.Client(model_host)
        self.keep_alive = config.get('keep_alive', None)
        self.ollama_options = config.get('options', {})
        self.num_ctx = self.ollama_options.get('num_ctx', 2048)

    def __call__(self, texts: List[str]):
        """Return one embedding (a list of floats) per entry of ``texts``."""
        # The result used to be computed and then dropped.
        return self._encode(texts)

    def _encode(self, texts: list[str]) -> list[list[float]]:
        """Request an embedding for each text, one Ollama call per text."""
        return [
            self.embedding_model.embeddings(
                model=self.embed_model,
                prompt=text,
                options=self.ollama_options,
                keep_alive=self.keep_alive,
            )["embedding"]
            for text in texts
        ]

    def _encode_as_arrays(self, texts: List[str]) -> List[np.ndarray]:
        """Embed ``texts`` and convert every embedding to ``np.ndarray``."""
        return [np.array(embedding) for embedding in self._encode(texts)]

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed ``documents``; returns one ``np.ndarray`` per document."""
        return self._encode_as_arrays(documents)

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Embed ``queries``; returns one ``np.ndarray`` per query."""
        return self._encode_as_arrays(queries)


class VannaServer:
    """Wrapper around a Vanna instance that stores training data in Milvus.

    The constructor builds the Vanna object from ``config`` (see
    ``_initialize_vn`` for the keys that are read) and exposes helpers to
    train it, refresh stored training data and ask questions.
    """

    def __init__(self, config):
        self.config = config
        self.vn = self._initialize_vn()

    def _initialize_vn(self):
        """Build the Vanna object described by ``self.config``.

        Required keys (a missing one raises ``KeyError``): ``supplier``,
        ``llm_type``, ``model``, ``api_key``, ``milvus_uri``, ``sql_type``.

        Optional keys: ``ollama_host``, ``host``, ``dbname``, ``user``,
        ``password``, ``port`` (database settings fall back to the
        ``DB_HOST`` / ``DB_NAME`` / ``DB_USER`` / ``DB_PASSWORD`` / ``DB_PORT``
        environment variables), ``milvus_database``, ``embedding_host`` and
        ``embedding_model``.

        A database connection is opened only when ``sql_type`` is
        ``"postgres"`` or ``"mysql"``.
        """
        config = self.config
        # Read only so that a missing key still raises KeyError as before;
        # the value itself is not used in this method.
        supplier = config["supplier"]
        llm_type = config["llm_type"]
        model_ = config["model"]
        api_key = config["api_key"]
        ollama_host = config.get("ollama_host")
        milvus_uri = config["milvus_uri"]
        sql_type = config["sql_type"]

        host = config["host"] if "host" in config else os.getenv("DB_HOST", "localhost")
        dbname = config["dbname"] if "dbname" in config else os.getenv("DB_NAME", "dify_data")
        user = config["user"] if "user" in config else os.getenv("DB_USER", "root")
        password = config["password"] if "password" in config else os.getenv("DB_PASSWORD", "mysql")
        # Kept as a conditional so DB_PORT is parsed only when it is needed.
        port = config["port"] if "port" in config else int(os.getenv("DB_PORT", 3306))

        milvus_database = config.get("milvus_database", "test")
        milvus_client = MilvusClient(uri=milvus_uri, db_name=milvus_database)

        embedding_function = CustomEmbeddingFunction({
            "host": config.get("embedding_host", _DEFAULT_EMBEDDING_HOST),
            "embed_model": config.get("embedding_model", _DEFAULT_EMBEDDING_MODEL),  # BAAI/bge-m3
        })

        vanna_config = {
            'model': model_,                  # name of the chat model
            'milvus_client': milvus_client,   # Milvus client used as vector store
            "n_results": 12,
            "embedding_function": embedding_function,
        }
        if llm_type == "ollama":
            vanna_config['ollama_host'] = ollama_host  # address of the Ollama service
        else:
            vanna_config['api_key'] = api_key          # API key of the hosted model

        # Any llm_type other than "tongyi" / "deepseek" uses the Ollama chat class.
        chat_llm = {
            "tongyi": QianWenAI_Chat,
            "deepseek": DeepSeekChat,
        }.get(llm_type, Ollama)

        MyVanna = make_vanna_class(ChatClass=chat_llm)
        vn = MyVanna(vanna_config)

        if sql_type == "postgres":
            vn.connect_to_postgres(host=host, dbname=dbname, user=user, password=password, port=port)
        elif sql_type == "mysql":
            vn.connect_to_mysql(host=host, dbname=dbname, user=user, password=password, port=port)

        return vn

    def _delete_matching(self, collection_name, predicate):
        """Delete the rows of ``collection_name`` for which ``predicate`` is true.

        At most ``_MILVUS_QUERY_LIMIT`` rows are fetched; ``predicate``
        receives each fetched row. Nothing is deleted when no row matches.
        """
        rows = self.vn.milvus_client.query(
            collection_name=collection_name,
            output_fields=["*"],
            limit=_MILVUS_QUERY_LIMIT,
        )
        remove_ids = [row["id"] for row in (rows or []) if predicate(row)]
        if remove_ids:
            self.vn.milvus_client.delete(collection_name=collection_name, ids=remove_ids)

    def schema_train(self):
        """Train on the column metadata of the ``public`` schema."""
        # The information schema query may need some tweaking depending on
        # your database. This is a good starting point.
        df_information_schema = self.vn.run_sql(
            "SELECT * FROM INFORMATION_SCHEMA.COLUMNS where table_schema = 'public'"
        )

        # This will break up the information schema into bite-sized chunks
        # that can be referenced by the LLM.
        plan = self.vn.get_training_plan_generic(df_information_schema)
        self.vn.train(plan=plan)

    def refresh_create_table_ddl_train(self):
        """Rebuild the ``CREATE TABLE`` DDL entries of the "vannaddl" collection.

        Runs a query against the PostgreSQL catalog tables that assembles one
        DDL string per table of the ``public`` schema, removes the stored
        entries whose ``ddl`` starts with ``"CREATE TABLE "``, then trains on
        the fresh DDL strings and calls ``refresh_load`` on the collection.
        """
        sql = """
            SELECT
                'CREATE TABLE ' || C.TABLE_NAME || ' (' || C.COLUMN_NAMES || ');'
                || C.COMMENT_COLUMNS
                || CASE WHEN FK.FOREIGN_KEY_COLUMNS IS NOT NULL THEN FK.FOREIGN_KEY_COLUMNS ELSE '' END
                || CASE WHEN FK.FOREIGN_KEY_DESC IS NOT NULL THEN FK.FOREIGN_KEY_DESC ELSE '' END
                || 'COMMENT ON TABLE ' || C.TABLE_NAME || ' IS ''' || G.DESCRIPTION || ''';' AS DDL,
                C.TABLE_NAME
            FROM (
                SELECT
                    COL.TABLE_NAME,
                    COL.TABLE_SCHEMA,
                    STRING_AGG(
                        COL.COLUMN_NAME || ' ' || COL.DATA_TYPE
                        || COALESCE('(' || COL.CHARACTER_MAXIMUM_LENGTH || ')', '')
                        || COALESCE(' DEFAULT ' || COL.COLUMN_DEFAULT, '')
                        || CASE WHEN COL.IS_NULLABLE = 'NO' THEN ' NOT NULL' ELSE '' END,
                        ','
                    ) AS COLUMN_NAMES,
                    STRING_AGG(
                        'COMMENT ON COLUMN ' || COL.TABLE_NAME || '.' || COL.COLUMN_NAME || ' IS ''' || PGD.DESCRIPTION || ''';',
                        ''
                    ) AS COMMENT_COLUMNS
                FROM PG_CATALOG.PG_STATIO_ALL_TABLES AS ST
                INNER JOIN PG_CATALOG.PG_DESCRIPTION AS PGD
                    ON PGD.OBJOID = ST.RELID
                INNER JOIN INFORMATION_SCHEMA.COLUMNS AS COL
                    ON (
                        COL.TABLE_SCHEMA = ST.SCHEMANAME
                        AND COL.TABLE_NAME = ST.RELNAME
                        AND COL.ORDINAL_POSITION = PGD.OBJSUBID
                    )
                WHERE COL.TABLE_SCHEMA = 'public'
                GROUP BY COL.TABLE_SCHEMA, COL.TABLE_NAME
            ) C
            LEFT JOIN (
                SELECT
                    N.NSPNAME AS SCHEMA_NAME,
                    C.RELNAME AS TABLE_NAME,
                    D.DESCRIPTION
                FROM PG_CATALOG.PG_DESCRIPTION D
                JOIN PG_CATALOG.PG_CLASS C ON C.OID = D.OBJOID
                JOIN PG_CATALOG.PG_NAMESPACE N ON N.OID = C.RELNAMESPACE
                WHERE C.RELKIND = 'r' AND D.OBJSUBID = 0
            ) G ON G.SCHEMA_NAME = C.TABLE_SCHEMA AND G.TABLE_NAME = C.TABLE_NAME
            LEFT JOIN (
                SELECT
                    rel_src.relname AS source_table,
                    STRING_AGG(
                        'ALTER TABLE ' || rel_src.relname || ' ADD CONSTRAINT ' || con.conname || ' FOREIGN KEY (' || att_src.attname || ') REFERENCES ' || rel_tgt.relname || '(' || att_tgt.attname || ');' ,
                        ''
                    ) AS FOREIGN_KEY_COLUMNS,
                    STRING_AGG(
                        'COMMENT ON CONSTRAINT ' || con.conname || ' ON ' || rel_src.relname || ' IS ''' || d.description || ''';',
                        ''
                    ) AS FOREIGN_KEY_DESC
                FROM pg_constraint con
                JOIN pg_class rel_src ON rel_src.oid = con.conrelid
                JOIN pg_class rel_tgt ON rel_tgt.oid = con.confrelid
                JOIN pg_attribute att_src
                    ON att_src.attrelid = rel_src.oid AND att_src.attnum = ANY(con.conkey)
                JOIN pg_attribute att_tgt
                    ON att_tgt.attrelid = rel_tgt.oid AND att_tgt.attnum = ANY(con.confkey)
                LEFT JOIN pg_description d ON d.objoid = con.oid
                WHERE con.contype = 'f'
                GROUP BY rel_src.relname
            ) FK ON FK.source_table = C.TABLE_NAME
            WHERE C.TABLE_NAME NOT IN ('flyway_table_dict','flyway_schema_history')
        """
        # The catalog query may need some tweaking depending on your database.
        c_table_ddl_list = self.vn.run_sql(sql)
        # Convert the DataFrame into a list of dicts, one per table.
        c_table_ddl_records = c_table_ddl_list.to_dict(orient='records')

        self._delete_matching(
            "vannaddl",
            lambda row: row["ddl"].startswith("CREATE TABLE "),
        )

        for table_ddl in c_table_ddl_records:
            self.vn.train(ddl=table_ddl["ddl"])
        self.vn.milvus_client.refresh_load(collection_name="vannaddl")

    def refresh_schema_train(self):
        """Replace the schema documents of the "vannadoc" collection.

        Removes the stored documents that start with the schema prefix, runs
        ``schema_train`` again and calls ``refresh_load`` on the collection.
        """
        self._delete_matching(
            "vannadoc",
            lambda row: row["doc"].startswith(_SCHEMA_DOC_PREFIX),
        )
        self.schema_train()
        self.vn.milvus_client.refresh_load(collection_name="vannadoc")

    def update_schema_train_list(self, docs: list[str]):
        """Replace the non-schema documents of the "vannadoc" collection.

        Removes every stored document that does NOT start with the schema
        prefix, then trains on ``docs`` followed by the documents returned by
        ``get_dict_docs`` and calls ``refresh_load`` on the collection.

        ``docs`` is not modified.
        """
        self._delete_matching(
            "vannadoc",
            lambda row: not row["doc"].startswith(_SCHEMA_DOC_PREFIX),
        )

        # Build a new list rather than extending the caller's list in place.
        all_docs = [*docs, *self.get_dict_docs()]
        for doc in all_docs:
            self.vn.train(documentation=doc)
        self.vn.milvus_client.refresh_load(collection_name="vannadoc")

    def get_dict_docs(self) -> list[str]:
        """Build one documentation string per table of ``flyway_table_dict``.

        Rows are grouped by ``table_name``. Each document lists, for every
        row of the table, the column remark, column name and dictionary
        values. Tables appear in the order they first occur in the query
        result.
        """
        sql = "select id,table_name,column_name,column_remark,table_remark,dict_values from flyway_table_dict"
        c_table_dict_list = self.vn.run_sql(sql)
        # Convert the DataFrame into a list of dicts, one per row.
        c_table_dict_records = c_table_dict_list.to_dict(orient='records')

        # Group the rows by table; dict insertion order keeps output stable.
        grouped = defaultdict(list)
        for table_dict in c_table_dict_records:
            grouped[table_dict['table_name']].append(table_dict)

        dict_docs = []
        for columns_list in grouped.values():
            # Single quotes inside the f-strings keep them valid before
            # Python 3.12, which is the first version to allow quote reuse.
            dict_values = ';'.join(
                f"字段:{item['column_remark']}({item['column_name']})的值:{item['dict_values']}"
                for item in columns_list
            )
            column = columns_list[0]
            dict_docs.append(f"{column['table_remark']}表:{column['table_name']},{dict_values}")
        return dict_docs

    def vn_train(self, question="", sql="", documentation="", ddl=""):
        """Forward the non-empty arguments to ``self.vn.train``.

        ``question`` is used only together with ``sql``; ``sql`` alone is
        trained as a bare query. ``documentation`` and ``ddl`` are each
        trained independently when non-empty.
        """
        if question and sql:
            # Train a question/SQL pair.
            self.vn.train(question=question, sql=sql)
        elif sql:
            # SQL queries on their own can also be added as training data.
            self.vn.train(sql=sql)

        if documentation:
            # Documentation about business terminology or definitions.
            self.vn.train(documentation=documentation)

        if ddl:
            # DDL statements can also be added as training data.
            self.vn.train(ddl=ddl)

    def get_training_data(self):
        """Return whatever ``self.vn.get_training_data()`` returns."""
        return self.vn.get_training_data()

    def ask(self, question, visualize=True, auto_train=True, *args, **kwargs):
        """Delegate to the imported ``ask`` helper; returns ``(sql, df, fig)``."""
        sql, df, fig = ask(self.vn, question, visualize=visualize, auto_train=auto_train, *args, **kwargs)
        return sql, df, fig

    def generate_sql(self, question):
        """Return the SQL that ``self.vn`` generates for ``question``."""
        return self.vn.generate_sql(question=question)

    def run_sql(self, sql):
        """Run ``sql`` through ``self.vn`` and return its result."""
        return self.vn.run_sql(sql=sql)

    def training_data_export(self):
        """Return the "vannasql" collection as ``{"question", "sql"}`` dicts.

        At most ``_MILVUS_QUERY_LIMIT`` rows are exported. An empty list is
        returned when the query yields nothing.
        """
        training_data = self.vn.milvus_client.query(
            collection_name="vannasql",
            output_fields=["*"],
            limit=_MILVUS_QUERY_LIMIT,
        )
        return [
            {"question": t['text'], "sql": t['sql']}
            for t in (training_data or [])
        ]

    def training_data_import(self, data_list):
        """Import question/SQL pairs into the "vannasql" collection.

        Every item of ``data_list`` must have a ``question`` and a ``sql``
        that are neither ``None`` nor ``""``. Stored rows whose ``text``
        equals an imported question are deleted before training.

        Returns:
            ``True`` when the input is rejected because an item has an empty
            question or SQL (nothing is imported), ``False`` after a
            completed import. Note the flag signals an error, not success.
        """
        has_empty_item = any(
            item['question'] is None or item['question'] == ""
            or item['sql'] is None or item['sql'] == ""
            for item in data_list
        )
        if has_empty_item:
            return True

        imported_questions = {item["question"] for item in data_list}
        self._delete_matching(
            "vannasql",
            lambda row: row['text'] in imported_questions,
        )

        for item in data_list:
            self.vn.train(
                question=item["question"],
                sql=item["sql"],
            )
        self.vn.milvus_client.refresh_load(collection_name="vannasql")
        return False


def make_vanna_class(ChatClass=Ollama):
    """Return a Vanna class combining ``Milvus_VectorStore`` with ``ChatClass``."""

    class MyVanna(Milvus_VectorStore, ChatClass):
        def __init__(self, config=None):
            # Both bases are initialised explicitly with the same config.
            Milvus_VectorStore.__init__(self, config=config)
            ChatClass.__init__(self, config=config)

        def is_sql_valid(self, sql: str) -> bool:
            # NOTE(review): always reports the SQL as invalid; the original
            # carried a "Your implementation here" placeholder. Confirm
            # whether this is intentional.
            return False

        def generate_query_explanation(self, sql: str):
            """Ask the chat model to explain ``sql`` and return its answer."""
            my_prompt = [
                self.system_message("You are a helpful assistant that will explain a SQL query"),
                self.user_message("Explain this SQL query: " + sql),
            ]
            return self.submit_prompt(prompt=my_prompt)

    return MyVanna


# Usage example
if __name__ == '__main__':
    # NOTE(review): `_initialize_vn` also indexes "llm_type", "model",
    # "api_key", "milvus_uri" and "sql_type" directly, so this config raises
    # KeyError as written; fill those keys in before running.
    config = {"supplier": "GITEE"}
    server = VannaServer(config)
    # server.schema_train()
    server.ask("汇总每个类别的销售量和销售额, 并按照销售量进行降序排列")