feat: add AWS Managed IAM Auth support for OpenSearch vector DB

pull/18963/head
Ahmad Zidan 1 year ago
parent a54773fbff
commit ad9e06c962
No known key found for this signature in database
GPG Key ID: 2CEE982320CE9D20

@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
@ -33,3 +33,19 @@ class OpenSearchConfig(BaseSettings):
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_USE_AWS_MANAGED_IAM: bool = Field(
description="Whether to use AWS IAM authentication for OpenSearch clusters "
"running in Amazon Managed OpenSearch or OpenSearch Serverless",
default=False,
)
OPENSEARCH_AWS_REGION: Optional[str] = Field(
description="AWS region for OpenSearch (e.g. 'us-west-2')",
default=None,
)
OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] = Field(
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)",
default="aoss"
)

@ -1,10 +1,9 @@
import json
import logging
import ssl
from typing import Any, Optional
from uuid import uuid4
from opensearchpy import OpenSearch, helpers
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
@ -26,6 +25,9 @@ class OpenSearchConfig(BaseModel):
port: int
user: Optional[str] = None
password: Optional[str] = None
use_aws_managed_iam: bool = False
aws_region: Optional[str] = None
aws_service: Optional[str] = None
secure: bool = False
@model_validator(mode="before")
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
if values.get("use_aws_managed_iam"):
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required")
return values
def create_ssl_context(self) -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
return ssl_context
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3 # type: ignore
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),
region=self.aws_region,
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
if self.user and self.password:
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
if self.secure:
params["ssl_context"] = self.create_ssl_context()
elif self.use_aws_managed_iam:
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
return params
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuid4().hex,
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
action["_id"] = uuid4().hex
actions.append(action)
helpers.bulk(self._client, actions)
helpers.bulk(
client=self._client,
actions=actions,
timeout=30,
max_retries=3,
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
},
}
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -254,6 +280,9 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
use_aws_managed_iam=dify_config.OPENSEARCH_USE_AWS_MANAGED_IAM,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
secure=dify_config.OPENSEARCH_SECURE,
)

@ -23,6 +23,63 @@ def setup_mock_redis():
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
class TestOpenSearchConfig:
def test_to_opensearch_params(self):
config = OpenSearchConfig(
host="localhost",
port=9200,
user="admin",
password="password",
secure=True,
)
params = config.to_opensearch_params()
assert params["hosts"] == [{"host": "localhost", "port": 9200}]
assert params["use_ssl"] is True
assert params["verify_certs"] is True
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
assert params["http_auth"] == ("admin", "password")
@patch("boto3.Session")
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
def test_to_opensearch_params_with_aws_managed_iam(
self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
):
mock_credentials = MagicMock()
mock_boto_session.return_value.get_credentials.return_value = mock_credentials
mock_auth_instance = MagicMock()
mock_aws_signer_auth.return_value = mock_auth_instance
aws_region = "ap-southeast-2"
aws_service = "aoss"
host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
port = 9201
config = OpenSearchConfig(
host=host,
port=port,
use_aws_managed_iam=True,
aws_region=aws_region,
aws_service=aws_service,
secure=True,
)
params = config.to_opensearch_params()
assert params["hosts"] == [{"host": host, "port": port}]
assert params["use_ssl"] is True
assert params["verify_certs"] is True
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
assert params["http_auth"] is mock_auth_instance
mock_aws_signer_auth.assert_called_once_with(
credentials=mock_credentials, region=aws_region, service=aws_service
)
assert mock_boto_session.return_value.get_credentials.called
class TestOpenSearchVector:
def setup_method(self):
self.collection_name = "test_collection"

@ -523,6 +523,10 @@ OPENSEARCH_PORT=9200
OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true
# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless, set to true.
OPENSEARCH_USE_AWS_MANAGED_IAM=false
OPENSEARCH_AWS_REGION=ap-southeast-1
OPENSEARCH_AWS_SERVICE=aoss
# tencent vector configurations, only available when VECTOR_STORE is `tencent`
TENCENT_VECTOR_DB_URL=http://127.0.0.1

@ -228,6 +228,9 @@ x-shared-env: &shared-api-worker-env
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
OPENSEARCH_USE_AWS_MANAGED_IAM: ${OPENSEARCH_USE_AWS_MANAGED_IAM:-false}
OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1}
OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss}
TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}

Loading…
Cancel
Save