diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 2dcf1710b0..279410ac8c 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -42,6 +42,7 @@ from .vdb.upstash_config import UpstashConfig
 from .vdb.vastbase_vector_config import VastbaseVectorConfig
 from .vdb.vikingdb_config import VikingDBConfig
 from .vdb.weaviate_config import WeaviateConfig
+from .vdb.vanna_config import VannaConfig
 
 
 class StorageConfig(BaseSettings):
@@ -323,5 +324,6 @@ class MiddlewareConfig(
     OpenGaussConfig,
     TableStoreConfig,
     DatasetQueueMonitorConfig,
+    VannaConfig,
 ):
     pass
diff --git a/api/configs/middleware/vdb/vanna_config.py b/api/configs/middleware/vdb/vanna_config.py
new file mode 100644
index 0000000000..da96cc3f1e
--- /dev/null
+++ b/api/configs/middleware/vdb/vanna_config.py
@@ -0,0 +1,87 @@
+from typing import Optional
+
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class VannaConfig(BaseSettings):
+    """
+    Configuration settings for the Vanna integration.
+
+    Covers the embedding model, the LLM, the training SQL database and the
+    training Milvus vector database.
+    """
+
+    VANNA_EMBEDDING_HOST: Optional[str] = Field(
+        description="vanna 向量模型地址",
+        default="http://127.0.0.1:19042",
+    )
+
+    VANNA_EMBEDDING_MODEL: Optional[str] = Field(
+        description="vanna 向量模型名称",
+        default='bge-m3',
+    )
+
+    VANNA_EMBEDDING_TYPE: Optional[str] = Field(
+        description="vanna 向量模型类型,默认是localhost,可以是ollama或其他类型",
+        default="localhost",
+    )
+
+    VANNA_LLM_TYPE: Optional[str] = Field(
+        description="vanna 语言模型类型,默认是deepseek,可以是ollama或其他类型",
+        default="deepseek",
+    )
+
+    VANNA_MODEL: str = Field(
+        description="vanna 语言模型版本,默认是deepseek-coder",
+        default="deepseek-coder",
+    )
+
+    # The default is None (no key is shipped), so the annotation must be Optional.
+    VANNA_API_KEY: Optional[str] = Field(
+        description="vanna 大模型API KEY",
+        default=None,
+    )
+
+    VANNA_SQL_TYPE: Optional[str] = Field(
+        description='vanna 训练数据库类型,默认是 postgres',
+        default="postgres",
+    )
+    VANNA_DB_USERNAME: Optional[str] = Field(
+        description='vanna 训练数据库用户名,默认是 postgres',
+        default='postgres',
+    )
+    # NOTE(review): hard-coded default credential - override via environment; consider defaulting to None.
+    VANNA_DB_PASSWORD: Optional[str] = Field(
+        description='vanna 训练数据库密码',
+        default='difyai123456',
+    )
+    VANNA_DB_HOST: Optional[str] = Field(
+        description='vanna 训练数据库地址,默认是 localhost',
+        default='localhost',
+    )
+    VANNA_DB_PORT: PositiveInt = Field(
+        description='vanna 训练数据库端口号,默认是 5432',
+        default=5432,
+    )
+    VANNA_DB_DATABASE: Optional[str] = Field(
+        description='vanna 训练数据库名称,默认是 vanna_demo',
+        default='vanna_demo',
+    )
+    VANNA_MILVUS_URI: Optional[str] = Field(
+        description='vanna 训练向量数据库地址,默认是 localhost:19530',
+        default='localhost:19530',
+    )
+    VANNA_MILVUS_USER: Optional[str] = Field(
+        description='vanna 训练向量数据库用户名,默认是 root',
+        default='root',
+    )
+    # NOTE(review): hard-coded default credential - override via environment; consider defaulting to None.
+    VANNA_MILVUS_PASSWORD: Optional[str] = Field(
+        description='vanna 训练向量数据库密码,默认是 Milvus',
+        default='Milvus',
+    )
+    VANNA_MILVUS_DATABASE: str = Field(
+        description='vanna 训练向量数据库名称,默认是 vanna_demo',
+        default='vanna_demo',
+    )
diff --git a/api/extensions/ext_vanna_server.py b/api/extensions/ext_vanna_server.py
index 01d2844486..089f08aa31 100644
--- a/api/extensions/ext_vanna_server.py
+++ b/api/extensions/ext_vanna_server.py
@@ -15,8 +15,11 @@ from datetime import datetime
 class Config:
     def __init__(self, supplier):
         self.embedding_supplier = "SiliconFlow"
-        self.milvus_uri = dify_config.MILVUS_URI
-        self.milvus_database = 'vanna_demo'
+        self.milvus_uri = dify_config.VANNA_MILVUS_URI
+        self.milvus_database = dify_config.VANNA_MILVUS_DATABASE
+        self.embedding_host = dify_config.VANNA_EMBEDDING_HOST
+        self.embedding_model = dify_config.VANNA_EMBEDDING_MODEL
+        self.embedding_type = dify_config.VANNA_EMBEDDING_TYPE
         self.supplier = supplier
         # self.llm_type = 'tongyi'
         # self.model = 'qwen-max'
@@ -24,16 +27,16 @@ class Config:
         # 本地模型
         # self.ollama_host = 'http://wsd.wisdomidata.com:19042'
         # self.model = 'qwen2:7b'
-        self.llm_type = 'deepseek'
-        self.model = 'deepseek-coder'
-        self.api_key = 'sk-0382990b7a90496c889774b1d3843f90'
-        self.sql_type = 'postgres'
+        self.llm_type = dify_config.VANNA_LLM_TYPE
+        self.model = dify_config.VANNA_MODEL
+        self.api_key = dify_config.VANNA_API_KEY
+        self.sql_type = dify_config.VANNA_SQL_TYPE
         self.sql_config = {
-            "host": dify_config.DB_HOST,
-            "dbname": 'vanna_demo',
-            "user": dify_config.DB_USERNAME,
-            "password": dify_config.DB_PASSWORD,
-            "port": dify_config.DB_PORT
+            "host": dify_config.VANNA_DB_HOST,
+            "dbname": dify_config.VANNA_DB_DATABASE,
+            "user": dify_config.VANNA_DB_USERNAME,
+            "password": dify_config.VANNA_DB_PASSWORD,
+            "port": dify_config.VANNA_DB_PORT
         }
 
 # 存储不同的 VannaServer 实例
@@ -214,7 +217,7 @@ def init_app(app: DifyApp):
 
     @app.route('/api/training/data/import', methods=['POST'])
     def training_data_import():
-        
+
         if 'file' not in request.files:
             return jsonify({"type": "error", "error": "未上传文件"}), 400
 
@@ -235,4 +238,4 @@ def init_app(app: DifyApp):
             return jsonify({'status': 'success'}), 200
 
         except Exception as e:
-            return jsonify({"type": "error", "error": f"文件解析失败: {str(e)}"}), 500
\ No newline at end of file
+            return jsonify({"type": "error", "error": f"文件解析失败: {str(e)}"}), 500
diff --git a/api/extensions/utils/vanna_text2sql.py b/api/extensions/utils/vanna_text2sql.py
index 311759c087..bb6629e46d 100644
--- a/api/extensions/utils/vanna_text2sql.py
+++ b/api/extensions/utils/vanna_text2sql.py
@@ -69,16 +69,20 @@ class VannaServer:
         milvus_database = config["milvus_database"] if "milvus_database" in config else "test"
         milvus_client = MilvusClient(uri=milvus_uri,db_name=milvus_database)
 
+        # Fall back to "ollama" when the key is absent: that was the only embedding path before this switch existed.
+        embedding_type = config["embedding_type"] if "embedding_type" in config else "ollama"
         embedding_host = config["embedding_host"] if "embedding_host" in config else 'http://wsd.wisdomidata.com:19042'
         embedding_model = config["embedding_model"] if "embedding_model" in config else "bge-m3" # BAAI/bge-m3
-        # embedding_function = model.dense.SentenceTransformerEmbeddingFunction(
-        #     model_name=embedding_model,
-        #     device='cpu' # 'cpu' or 'cuda:0'
-        # )
-        embedding_function = CustomEmbeddingFunction({
-            "host": embedding_host,
-            "embed_model": embedding_model
-        })
+        if embedding_type == "ollama":
+            embedding_function = CustomEmbeddingFunction({
+                "host": embedding_host,
+                "embed_model": embedding_model
+            })
+        else:
+            embedding_function = model.dense.SentenceTransformerEmbeddingFunction(
+                model_name=embedding_model,
+                device='cpu' # 'cpu' or 'cuda:0'
+            )
 
         chat_llm = Ollama
         if llm_type == "ollama":
             config = {