From 753d08b93f9ce5fdf253707ec9c83473df91e453 Mon Sep 17 00:00:00 2001 From: "liuchangsheng@wisdomidata.com" Date: Fri, 20 Jun 2025 14:24:57 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Dify=E3=80=91=20vanna.ai=20=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E8=BF=98=E5=8E=9F=EF=BC=8C=E6=9A=82=E6=97=B6=E4=B8=8D?= =?UTF-8?q?=E7=94=A8ollama?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/extensions/utils/vanna_text2sql.py | 124 +++++++++++++------------ 1 file changed, 64 insertions(+), 60 deletions(-) diff --git a/api/extensions/utils/vanna_text2sql.py b/api/extensions/utils/vanna_text2sql.py index 72514eb9ae..c8b762ff43 100644 --- a/api/extensions/utils/vanna_text2sql.py +++ b/api/extensions/utils/vanna_text2sql.py @@ -71,10 +71,14 @@ class VannaServer: embedding_host = config["embedding_host"] if "embedding_host" in config else 'http://wsd.wisdomidata.com:19042' embedding_model = config["embedding_model"] if "embedding_model" in config else "bge-m3" # BAAI/bge-m3 - embedding_function = CustomEmbeddingFunction({ - "host": embedding_host, - "embed_model": embedding_model - }) + embedding_function = model.dense.SentenceTransformerEmbeddingFunction( + model_name=embedding_model, + device='cpu' # 'cpu' or 'cuda:0' + ) + # embedding_function = CustomEmbeddingFunction({ + # "host": embedding_host, + # "embed_model": embedding_model + # }) chat_llm = Ollama if llm_type == "ollama": config = { @@ -120,94 +124,94 @@ class VannaServer: # 更新建表DDL语句 def refresh_create_table_ddl_train(self): sql = """ -SELECT - 'CREATE TABLE ' - || C.TABLE_NAME - || ' (' - || C.COLUMN_NAMES +SELECT + 'CREATE TABLE ' + || C.TABLE_NAME + || ' (' + || C.COLUMN_NAMES || ');' || C.COMMENT_COLUMNS || CASE WHEN FK.FOREIGN_KEY_COLUMNS IS NOT NULL THEN FK.FOREIGN_KEY_COLUMNS ELSE '' END || CASE WHEN FK.FOREIGN_KEY_DESC IS NOT NULL THEN FK.FOREIGN_KEY_DESC ELSE '' END - || 'COMMENT ON TABLE ' - || C.TABLE_NAME - || ' IS ''' - || G.DESCRIPTION + || 'COMMENT ON TABLE ' + || C.TABLE_NAME + || ' IS ''' + || G.DESCRIPTION || ''';' AS DDL, C.TABLE_NAME FROM ( - SELECT + SELECT COL.TABLE_NAME, COL.TABLE_SCHEMA, STRING_AGG( - COL.COLUMN_NAME - || ' ' - || COL.DATA_TYPE - || COALESCE('(' || COL.CHARACTER_MAXIMUM_LENGTH || ')', '') - || COALESCE(' DEFAULT ' || COL.COLUMN_DEFAULT, '') - || CASE - WHEN COL.IS_NULLABLE = 'NO' THEN ' NOT NULL' - ELSE '' + COL.COLUMN_NAME + || ' ' + || COL.DATA_TYPE + || COALESCE('(' || COL.CHARACTER_MAXIMUM_LENGTH || ')', '') + || COALESCE(' DEFAULT ' || COL.COLUMN_DEFAULT, '') + || CASE + WHEN COL.IS_NULLABLE = 'NO' THEN ' NOT NULL' + ELSE '' END, ',' ) AS COLUMN_NAMES, STRING_AGG( - 'COMMENT ON COLUMN ' - || COL.TABLE_NAME - || '.' - || COL.COLUMN_NAME - || ' IS ''' - || PGD.DESCRIPTION + 'COMMENT ON COLUMN ' + || COL.TABLE_NAME + || '.' + || COL.COLUMN_NAME + || ' IS ''' + || PGD.DESCRIPTION || ''';', '' ) AS COMMENT_COLUMNS - FROM + FROM PG_CATALOG.PG_STATIO_ALL_TABLES AS ST - INNER JOIN - PG_CATALOG.PG_DESCRIPTION AS PGD + INNER JOIN + PG_CATALOG.PG_DESCRIPTION AS PGD ON PGD.OBJOID = ST.RELID - INNER JOIN - INFORMATION_SCHEMA.COLUMNS AS COL + INNER JOIN + INFORMATION_SCHEMA.COLUMNS AS COL ON ( - COL.TABLE_SCHEMA = ST.SCHEMANAME - AND COL.TABLE_NAME = ST.RELNAME + COL.TABLE_SCHEMA = ST.SCHEMANAME + AND COL.TABLE_NAME = ST.RELNAME AND COL.ORDINAL_POSITION = PGD.OBJSUBID ) - WHERE + WHERE COL.TABLE_SCHEMA = 'public' - GROUP BY - COL.TABLE_SCHEMA, + GROUP BY + COL.TABLE_SCHEMA, COL.TABLE_NAME ) C LEFT JOIN ( - SELECT + SELECT N.NSPNAME AS SCHEMA_NAME, C.RELNAME AS TABLE_NAME, D.DESCRIPTION - FROM + FROM PG_CATALOG.PG_DESCRIPTION D - JOIN - PG_CATALOG.PG_CLASS C + JOIN + PG_CATALOG.PG_CLASS C ON C.OID = D.OBJOID - JOIN - PG_CATALOG.PG_NAMESPACE N + JOIN + PG_CATALOG.PG_NAMESPACE N ON N.OID = C.RELNAMESPACE - WHERE + WHERE C.RELKIND = 'r' AND D.OBJSUBID = 0 -) G -ON G.SCHEMA_NAME = C.TABLE_SCHEMA +) G +ON G.SCHEMA_NAME = C.TABLE_SCHEMA AND G.TABLE_NAME = C.TABLE_NAME LEFT JOIN ( SELECT rel_src.relname AS source_table, STRING_AGG( - 'ALTER TABLE ' + 'ALTER TABLE ' || rel_src.relname - || ' ADD CONSTRAINT ' - || con.conname + || ' ADD CONSTRAINT ' + || con.conname || ' FOREIGN KEY (' - || att_src.attname + || att_src.attname || ') REFERENCES ' || rel_tgt.relname || '(' @@ -217,11 +221,11 @@ LEFT JOIN ( '' ) AS FOREIGN_KEY_COLUMNS, STRING_AGG( - 'COMMENT ON CONSTRAINT ' - || con.conname - || ' ON ' - || rel_src.relname - || ' IS ''' + 'COMMENT ON CONSTRAINT ' + || con.conname + || ' ON ' + || rel_src.relname + || ' IS ''' || d.description || ''';', '' @@ -379,19 +383,19 @@ WHERE C.TABLE_NAME NOT IN ('flyway_table_dict','flyway_schema_history') limit=10000, ) data_texts = {t["question"]: t for t in data_list} - + if bool(exist_doc_data): remove_ids = [item["id"] for item in exist_doc_data if item['text'] in data_texts ] - + if bool(remove_ids): self.vn.milvus_client.delete(collection_name="vannasql", ids=remove_ids) - + for item in data_list: self.vn.train( question=item["question"], sql=item["sql"], ) - + self.vn.milvus_client.refresh_load(collection_name="vannasql") return False