diff --git a/api/extensions/utils/vanna_text2sql.py b/api/extensions/utils/vanna_text2sql.py index dbca9cee0b..311759c087 100644 --- a/api/extensions/utils/vanna_text2sql.py +++ b/api/extensions/utils/vanna_text2sql.py @@ -70,15 +70,15 @@ class VannaServer: milvus_client = MilvusClient(uri=milvus_uri,db_name=milvus_database) embedding_host = config["embedding_host"] if "embedding_host" in config else 'http://wsd.wisdomidata.com:19042' - embedding_model = config["embedding_model"] if "embedding_model" in config else "BAAI/bge-m3" # BAAI/bge-m3 - embedding_function = model.dense.SentenceTransformerEmbeddingFunction( - model_name=embedding_model, - device='cpu' # 'cpu' or 'cuda:0' - ) - # embedding_function = CustomEmbeddingFunction({ - # "host": embedding_host, - # "embed_model": embedding_model - # }) + embedding_model = config["embedding_model"] if "embedding_model" in config else "bge-m3" # BAAI/bge-m3 + # embedding_function = model.dense.SentenceTransformerEmbeddingFunction( + # model_name=embedding_model, + # device='cpu' # 'cpu' or 'cuda:0' + # ) + embedding_function = CustomEmbeddingFunction({ + "host": embedding_host, + "embed_model": embedding_model + }) chat_llm = Ollama if llm_type == "ollama": config = {