|
|
|
|
@ -2,12 +2,12 @@ import array
|
|
|
|
|
import json
|
|
|
|
|
import re
|
|
|
|
|
import uuid
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
import jieba.posseg as pseg # type: ignore
|
|
|
|
|
import numpy
|
|
|
|
|
import oracledb
|
|
|
|
|
from oracledb.connection import Connection
|
|
|
|
|
from pydantic import BaseModel, model_validator
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
@ -70,6 +70,7 @@ class OracleVector(BaseVector):
|
|
|
|
|
super().__init__(collection_name)
|
|
|
|
|
self.pool = self._create_connection_pool(config)
|
|
|
|
|
self.table_name = f"embedding_{collection_name}"
|
|
|
|
|
self.config = config
|
|
|
|
|
|
|
|
|
|
def get_type(self) -> str:
    """Return the vector-store type identifier of this backend.

    :return: ``VectorType.ORACLE``.
    """
    backend_type = VectorType.ORACLE
    return backend_type
|
|
|
|
|
@ -107,16 +108,19 @@ class OracleVector(BaseVector):
|
|
|
|
|
outconverter=self.numpy_converter_out,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _get_connection(self) -> Connection:
    """Open a new, non-pooled Oracle connection from the stored configuration.

    Only ``user``, ``password`` and ``dsn`` are read from ``self.config``.
    The connection is returned open; closing it is left to the caller.

    :return: A freshly opened ``oracledb`` connection.
    """
    # NOTE(review): _create_connection_pool adds wallet settings when
    # config.is_autonomous is set; none of them are passed here - confirm
    # that direct connections work for autonomous databases.
    credentials = {
        "user": self.config.user,
        "password": self.config.password,
        "dsn": self.config.dsn,
    }
    return oracledb.connect(**credentials)
|
|
|
|
|
|
|
|
|
|
def _create_connection_pool(self, config: OracleVectorConfig):
|
|
|
|
|
pool_params = {
|
|
|
|
|
"user": config.user,
|
|
|
|
|
"password": config.password,
|
|
|
|
|
"dsn": config.dsn,
|
|
|
|
|
"min": 1,
|
|
|
|
|
"max": 50,
|
|
|
|
|
"max": 5,
|
|
|
|
|
"increment": 1,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if config.is_autonomous:
|
|
|
|
|
pool_params.update(
|
|
|
|
|
{
|
|
|
|
|
@ -125,22 +129,8 @@ class OracleVector(BaseVector):
|
|
|
|
|
"wallet_password": config.wallet_password,
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return oracledb.create_pool(**pool_params)
|
|
|
|
|
|
|
|
|
|
@contextmanager
def _get_cursor(self):
    """Yield a cursor on a pooled connection and finish the transaction on exit.

    A connection is acquired from ``self.pool``, the instance's input/output
    type handlers are installed on it, and a cursor is yielded to the caller.

    * If the ``with`` body completes normally the transaction is committed.
    * If the body raises, the transaction is rolled back and the exception
      is re-raised, so partially applied work is never committed.

    In both cases the cursor is closed and the connection is closed
    (python-oracledb hands a pool-acquired connection back to its pool on
    ``close()``).

    :return: Context manager yielding an open cursor.
    """
    conn = self.pool.acquire()
    try:
        conn.inputtypehandler = self.input_type_handler
        conn.outputtypehandler = self.output_type_handler
        cur = conn.cursor()
        try:
            yield cur
        except BaseException:
            # Do not persist half-finished work when the caller failed.
            conn.rollback()
            raise
        else:
            conn.commit()
        finally:
            cur.close()
    finally:
        # Runs even when handler setup or cursor creation fails, so the
        # acquired connection is never leaked.
        conn.close()
|
|
|
|
|
|
|
|
|
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
|
|
|
|
dimension = len(embeddings[0])
|
|
|
|
|
self._create_collection(dimension)
|
|
|
|
|
@ -162,41 +152,68 @@ class OracleVector(BaseVector):
|
|
|
|
|
numpy.array(embeddings[i]),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
|
|
|
|
|
with self._get_cursor() as cur:
|
|
|
|
|
cur.executemany(
|
|
|
|
|
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
|
|
|
|
|
)
|
|
|
|
|
with self._get_connection() as conn:
|
|
|
|
|
conn.inputtypehandler = self.input_type_handler
|
|
|
|
|
conn.outputtypehandler = self.output_type_handler
|
|
|
|
|
# with conn.cursor() as cur:
|
|
|
|
|
# cur.executemany(
|
|
|
|
|
# f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
|
|
|
|
|
# )
|
|
|
|
|
# conn.commit()
|
|
|
|
|
for value in values:
|
|
|
|
|
with conn.cursor() as cur:
|
|
|
|
|
try:
|
|
|
|
|
cur.execute(
|
|
|
|
|
f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
|
|
|
|
|
VALUES (:1, :2, :3, :4)""",
|
|
|
|
|
value,
|
|
|
|
|
)
|
|
|
|
|
conn.commit()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(e)
|
|
|
|
|
conn.close()
|
|
|
|
|
return pks
|
|
|
|
|
|
|
|
|
|
def text_exists(self, id: str) -> bool:
    """Report whether a row with the given primary key exists in the table.

    The id is passed as a bind variable rather than interpolated into the
    SQL text, so ids containing quotes are matched literally and cannot
    alter the statement.

    :param id: Primary-key value to look up.
    :return: ``True`` if at least one matching row exists, else ``False``.
    """
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            # Only the table name is formatted into the statement; the
            # value travels as positional bind :1.
            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", [id])
            return cur.fetchone() is not None
|
|
|
|
|
|
|
|
|
|
def get_by_ids(self, ids: list[str]) -> list[Document]:
    """Fetch the stored documents whose primary keys are in ``ids``.

    One positional bind variable (``:1``, ``:2``, ...) is generated per id,
    because a whole Python tuple cannot be bound to a single ``IN``
    placeholder with the Oracle driver.

    :param ids: Primary-key values to fetch; an empty list returns ``[]``
        without contacting the database.
    :return: One ``Document`` per matching row, built from the ``text`` and
        ``meta`` columns.
    """
    if not ids:
        # "IN ()" is not valid SQL, and there is nothing to fetch anyway.
        return []

    placeholders = ", ".join(f":{pos}" for pos in range(1, len(ids) + 1))
    docs = []
    # The connection comes from _get_connection(), not from self.pool, so it
    # is closed by its context manager instead of being released to the pool.
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})",
                list(ids),
            )
            for meta, text in cur:
                docs.append(Document(page_content=text, metadata=meta))
    return docs
|
|
|
|
|
|
|
|
|
|
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete the rows whose primary keys are in ``ids``.

    One positional bind variable is generated per id. Formatting
    ``tuple(ids)`` into the statement produced ``('a',)`` for a single id,
    which is invalid SQL, and let id values alter the statement text.

    :param ids: Primary-key values to delete; an empty list is a no-op.
    """
    if not ids:
        return

    placeholders = ", ".join(f":{pos}" for pos in range(1, len(ids) + 1))
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})",
                list(ids),
            )
        conn.commit()
|
|
|
|
|
|
|
|
|
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
    """Delete every row whose JSON metadata has ``key`` equal to ``value``.

    The previous statement used PostgreSQL's ``->>`` operator and ``%s``
    placeholders, neither of which the Oracle driver understands. The
    lookup is expressed with ``JSON_VALUE`` instead and the value is sent
    as a bind variable.

    :param key: Top-level metadata field name. It is embedded in the JSON
        path expression, so only plain identifiers (letters, digits,
        underscore, not starting with a digit) are accepted.
    :param value: Value the field must equal for a row to be deleted.
    :raises ValueError: If ``key`` is not a plain identifier.
    """
    # The key is validated because it is formatted into the statement text
    # rather than bound.
    # NOTE(review): assumes JSON_VALUE requires a literal path - confirm.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", key):
        raise ValueError(f"Unsupported metadata key: {key!r}")

    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$.{key}') = :1",
                [value],
            )
        conn.commit()
|
|
|
|
|
|
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
|
|
|
|
"""
|
|
|
|
|
Search the nearest neighbors to a vector.
|
|
|
|
|
|
|
|
|
|
:param query_vector: The input vector to search for similar items.
|
|
|
|
|
:param top_k: The number of nearest neighbors to return, default is 5.
|
|
|
|
|
:return: List of Documents that are nearest to the query vector.
|
|
|
|
|
"""
|
|
|
|
|
top_k = kwargs.get("top_k", 4)
|
|
|
|
|
@ -205,20 +222,25 @@ class OracleVector(BaseVector):
|
|
|
|
|
if document_ids_filter:
|
|
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
|
|
|
|
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
|
|
|
|
with self._get_cursor() as cur:
|
|
|
|
|
cur.execute(
|
|
|
|
|
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
|
|
|
|
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
|
|
|
|
|
[numpy.array(query_vector)],
|
|
|
|
|
)
|
|
|
|
|
docs = []
|
|
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
|
|
|
|
for record in cur:
|
|
|
|
|
metadata, text, distance = record
|
|
|
|
|
score = 1 - distance
|
|
|
|
|
metadata["score"] = score
|
|
|
|
|
if score > score_threshold:
|
|
|
|
|
docs.append(Document(page_content=text, metadata=metadata))
|
|
|
|
|
with self._get_connection() as conn:
|
|
|
|
|
conn.inputtypehandler = self.input_type_handler
|
|
|
|
|
conn.outputtypehandler = self.output_type_handler
|
|
|
|
|
with conn.cursor() as cur:
|
|
|
|
|
cur.execute(
|
|
|
|
|
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
|
|
|
|
|
AS distance FROM {self.table_name}
|
|
|
|
|
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
|
|
|
|
|
[numpy.array(query_vector)],
|
|
|
|
|
)
|
|
|
|
|
docs = []
|
|
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
|
|
|
|
for record in cur:
|
|
|
|
|
metadata, text, distance = record
|
|
|
|
|
score = 1 - distance
|
|
|
|
|
metadata["score"] = score
|
|
|
|
|
if score > score_threshold:
|
|
|
|
|
docs.append(Document(page_content=text, metadata=metadata))
|
|
|
|
|
conn.close()
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
|
|
|
|
@ -228,7 +250,7 @@ class OracleVector(BaseVector):
|
|
|
|
|
|
|
|
|
|
top_k = kwargs.get("top_k", 5)
|
|
|
|
|
# just not implement fetch by score_threshold now, may be later
|
|
|
|
|
# score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
|
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
|
|
|
|
if len(query) > 0:
|
|
|
|
|
# Check which language the query is in
|
|
|
|
|
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
|
|
|
|
@ -239,7 +261,7 @@ class OracleVector(BaseVector):
|
|
|
|
|
words = pseg.cut(query)
|
|
|
|
|
current_entity = ""
|
|
|
|
|
for word, pos in words:
|
|
|
|
|
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
|
|
|
|
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
|
|
|
|
|
current_entity += word
|
|
|
|
|
else:
|
|
|
|
|
if current_entity:
|
|
|
|
|
@ -260,30 +282,35 @@ class OracleVector(BaseVector):
|
|
|
|
|
for token in all_tokens:
|
|
|
|
|
if token not in stop_words:
|
|
|
|
|
entities.append(token)
|
|
|
|
|
with self._get_cursor() as cur:
|
|
|
|
|
document_ids_filter = kwargs.get("document_ids_filter")
|
|
|
|
|
where_clause = ""
|
|
|
|
|
if document_ids_filter:
|
|
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
|
|
|
|
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
|
|
|
|
cur.execute(
|
|
|
|
|
f"select meta, text, embedding FROM {self.table_name}"
|
|
|
|
|
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
|
|
|
|
f"order by score(1) desc fetch first {top_k} rows only",
|
|
|
|
|
[" ACCUM ".join(entities)],
|
|
|
|
|
)
|
|
|
|
|
docs = []
|
|
|
|
|
for record in cur:
|
|
|
|
|
metadata, text, embedding = record
|
|
|
|
|
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
|
|
|
|
with self._get_connection() as conn:
|
|
|
|
|
with conn.cursor() as cur:
|
|
|
|
|
document_ids_filter = kwargs.get("document_ids_filter")
|
|
|
|
|
where_clause = ""
|
|
|
|
|
if document_ids_filter:
|
|
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
|
|
|
|
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
|
|
|
|
cur.execute(
|
|
|
|
|
f"""select meta, text, embedding FROM {self.table_name}
|
|
|
|
|
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
|
|
|
|
|
order by score(1) desc fetch first {top_k} rows only""",
|
|
|
|
|
kk=" ACCUM ".join(entities),
|
|
|
|
|
)
|
|
|
|
|
docs = []
|
|
|
|
|
for record in cur:
|
|
|
|
|
metadata, text, embedding = record
|
|
|
|
|
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
|
|
|
|
|
conn.close()
|
|
|
|
|
return docs
|
|
|
|
|
else:
|
|
|
|
|
return [Document(page_content="", metadata={})]
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
def delete(self) -> None:
    """Drop this collection's table, including dependent constraints.

    The statement uses ``IF EXISTS``. It is issued once, over a single
    connection obtained from ``_get_connection()``; the ``with`` block
    closes that connection on exit.
    """
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
        conn.commit()
|
|
|
|
|
|
|
|
|
|
def _create_collection(self, dimension: int):
|
|
|
|
|
cache_key = f"vector_indexing_{self._collection_name}"
|
|
|
|
|
@ -293,11 +320,14 @@ class OracleVector(BaseVector):
|
|
|
|
|
if redis_client.get(collection_exist_cache_key):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
with self._get_cursor() as cur:
|
|
|
|
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
|
|
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
|
|
with self._get_cursor() as cur:
|
|
|
|
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
|
|
|
|
with self._get_connection() as conn:
|
|
|
|
|
with conn.cursor() as cur:
|
|
|
|
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
|
|
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
|
|
with conn.cursor() as cur:
|
|
|
|
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
|
|
|
|
conn.commit()
|
|
|
|
|
conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OracleVectorFactory(AbstractVectorFactory):
|
|
|
|
|
|