|
|
|
@ -24,7 +24,7 @@ from pymilvus.model.base import BaseEmbeddingFunction
|
|
|
|
class CustomEmbeddingFunction(BaseEmbeddingFunction):
|
|
|
|
class CustomEmbeddingFunction(BaseEmbeddingFunction):
|
|
|
|
def __init__(self, config=None):
|
|
|
|
def __init__(self, config=None):
|
|
|
|
model_host = config['host'] if "host" in config else 'http://wsd.wisdomidata.com:19042'
|
|
|
|
model_host = config['host'] if "host" in config else 'http://wsd.wisdomidata.com:19042'
|
|
|
|
self.embed_model = config['embed_model'] if "embed_model" in config else 'bge-m3'
|
|
|
|
self.embed_model = config['embed_model'] if "embed_model" in config else 'BAAI/bge-m3'
|
|
|
|
self.embedding_model = ollama.Client(model_host)
|
|
|
|
self.embedding_model = ollama.Client(model_host)
|
|
|
|
self.keep_alive = config.get('keep_alive', None)
|
|
|
|
self.keep_alive = config.get('keep_alive', None)
|
|
|
|
self.ollama_options = config.get('options', {})
|
|
|
|
self.ollama_options = config.get('options', {})
|
|
|
|
@ -70,7 +70,7 @@ class VannaServer:
|
|
|
|
milvus_client = MilvusClient(uri=milvus_uri,db_name=milvus_database)
|
|
|
|
milvus_client = MilvusClient(uri=milvus_uri,db_name=milvus_database)
|
|
|
|
|
|
|
|
|
|
|
|
embedding_host = config["embedding_host"] if "embedding_host" in config else 'http://wsd.wisdomidata.com:19042'
|
|
|
|
embedding_host = config["embedding_host"] if "embedding_host" in config else 'http://wsd.wisdomidata.com:19042'
|
|
|
|
embedding_model = config["embedding_model"] if "embedding_model" in config else "bge-m3" # BAAI/bge-m3
|
|
|
|
embedding_model = config["embedding_model"] if "embedding_model" in config else "BAAI/bge-m3" # BAAI/bge-m3
|
|
|
|
embedding_function = model.dense.SentenceTransformerEmbeddingFunction(
|
|
|
|
embedding_function = model.dense.SentenceTransformerEmbeddingFunction(
|
|
|
|
model_name=embedding_model,
|
|
|
|
model_name=embedding_model,
|
|
|
|
device='cpu' # 'cpu' or 'cuda:0'
|
|
|
|
device='cpu' # 'cpu' or 'cuda:0'
|
|
|
|
|