@ -1,10 +1,9 @@
import json
import logging
import ssl
from typing import Any , Optional
from typing import Any , Literal , Optional
from uuid import uuid4
from opensearchpy import OpenSearch , helpers
from opensearchpy import OpenSearch , Urllib3AWSV4SignerAuth, Urllib3HttpConnection , helpers
from opensearchpy . helpers import BulkIndexError
from pydantic import BaseModel , model_validator
@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
# Connection and authentication settings for the OpenSearch vector store,
# declared as a pydantic BaseModel.
class OpenSearchConfig ( BaseModel ) :
# Address of the OpenSearch node; used to build the single "hosts" entry
# in to_opensearch_params.
host : str
port : int
# NOTE(review): `secure` is declared twice in this listing (here and after
# `password`) with the same type and default. This looks like the old and
# new position of one field in a diff; confirm which declaration is live.
secure : bool = False
# Selects how requests are authenticated; defaults to basic auth.
auth_method : Literal [ " basic " , " aws_managed_iam " ] = " basic "
# Credentials placed into `http_auth` as a (user, password) tuple by
# to_opensearch_params.
user : Optional [ str ] = None
password : Optional [ str ] = None
# Value reused for both `use_ssl` and `verify_certs` in to_opensearch_params.
secure : bool = False
# Region and service name handed to the AWS SigV4 signer; the validator
# rejects the config when either is missing and auth_method is the
# aws_managed_iam option.
aws_region : Optional [ str ] = None
aws_service : Optional [ str ] = None
@model_validator ( mode = " before " )
@classmethod
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError ( " config OPENSEARCH_HOST is required " )
if not values . get ( " port " ) :
raise ValueError ( " config OPENSEARCH_PORT is required " )
if values . get ( " auth_method " ) == " aws_managed_iam " :
if not values . get ( " aws_region " ) :
raise ValueError ( " config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method " )
if not values . get ( " aws_service " ) :
raise ValueError ( " config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method " )
return values
def create_ssl_context(self) -> ssl.SSLContext:
    """Build an SSL context that performs no certificate validation.

    Returns:
        A context from ``ssl.create_default_context()`` with hostname
        checking turned off and ``verify_mode`` set to ``ssl.CERT_NONE``.
    """
    ssl_context = ssl.create_default_context()
    # Order matters: hostname checking has to be switched off first, because
    # the ssl module raises ValueError if verify_mode is set to CERT_NONE
    # while check_hostname is still enabled.
    ssl_context.check_hostname = False
    # SECURITY: with verification disabled the server's identity is not
    # checked, so connections made with this context are open to
    # man-in-the-middle interception. Kept as-is because callers rely on it.
    ssl_context.verify_mode = ssl.CERT_NONE  # Disable Certificate Validation
    return ssl_context
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
    """Create the AWS SigV4 signer used as ``http_auth`` for IAM access.

    Credentials come from a default ``boto3.Session()``; the region and
    service name come from ``self.aws_region`` and ``self.aws_service``.

    Returns:
        A ``Urllib3AWSV4SignerAuth`` built from those three values.
    """
    # Function-scope import: boto3 is loaded only when this method is called.
    import boto3  # type: ignore

    return Urllib3AWSV4SignerAuth(
        credentials=boto3.Session().get_credentials(),
        region=self.aws_region,
        # aws_service is Optional[str] on the model, hence the ignore.
        service=self.aws_service,  # type: ignore[arg-type]
    )
# Translate this config into a dict of client parameters and return it.
# Presumably the dict is unpacked into the OpenSearch client constructor;
# the call site is not visible here, so confirm against the caller.
def to_opensearch_params ( self ) - > dict [ str , Any ] :
# Base parameters: one host entry, with `secure` driving both whether SSL
# is used and whether certificates are verified.
params = {
" hosts " : [ { " host " : self . host , " port " : self . port } ] ,
" use_ssl " : self . secure ,
" verify_certs " : self . secure ,
" connection_class " : Urllib3HttpConnection ,
" pool_maxsize " : 20 ,
}
# NOTE(review): the conditionals below appear to interleave two versions of
# this method (a user/password + ssl_context variant and an auth_method
# variant). Indentation is lost in this listing, so the effective nesting
# cannot be determined from here; confirm against the real file.
if self . user and self . password :
if self . auth_method == " basic " :
logger . info ( " Using basic authentication for OpenSearch Vector DB " )
# Basic auth: credentials are passed as a (user, password) tuple.
params [ " http_auth " ] = ( self . user , self . password )
if self . secure :
# Supplies the context from create_ssl_context, which disables certificate
# validation, even though "verify_certs" above is set from the same flag.
params [ " ssl_context " ] = self . create_ssl_context ( )
elif self . auth_method == " aws_managed_iam " :
logger . info ( " Using AWS managed IAM role for OpenSearch Vector DB " )
# IAM auth: requests are signed by the object from create_aws_managed_iam_auth.
params [ " http_auth " ] = self . create_aws_managed_iam_auth ( )
return params
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = {
" _op_type " : " index " ,
" _index " : self . _collection_name . lower ( ) ,
" _id " : uuid4 ( ) . hex ,
" _source " : {
Field . CONTENT_KEY . value : documents [ i ] . page_content ,
Field . VECTOR . value : embeddings [ i ] , # Make sure you pass an array here
Field . METADATA_KEY . value : documents [ i ] . metadata ,
} ,
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self . _client_config . aws_service not in [ " aoss " ] :
action [ " _id " ] = uuid4 ( ) . hex
actions . append ( action )
helpers . bulk ( self . _client , actions )
helpers . bulk (
client = self . _client ,
actions = actions ,
timeout = 30 ,
max_retries = 3 ,
)
def get_ids_by_metadata_field ( self , key : str , value : str ) :
query = { " query " : { " term " : { f " { Field . METADATA_KEY . value } . { key } " : value } } }
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
} ,
}
logger . info ( f " Creating OpenSearch index { self . _collection_name . lower ( ) } " )
self . _client . indices . create ( index = self . _collection_name . lower ( ) , body = index_body )
redis_client . set ( collection_exist_cache_key , 1 , ex = 3600 )
@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig (
host = dify_config . OPENSEARCH_HOST or " localhost " ,
port = dify_config . OPENSEARCH_PORT ,
secure = dify_config . OPENSEARCH_SECURE ,
auth_method = dify_config . OPENSEARCH_AUTH_METHOD . value ,
user = dify_config . OPENSEARCH_USER ,
password = dify_config . OPENSEARCH_PASSWORD ,
secure = dify_config . OPENSEARCH_SECURE ,
aws_region = dify_config . OPENSEARCH_AWS_REGION ,
aws_service = dify_config . OPENSEARCH_AWS_SERVICE ,
)
return OpenSearchVector ( collection_name = collection_name , config = open_search_config )