From ad9e06c962e850d9091f7775638b2aad203ae83d Mon Sep 17 00:00:00 2001 From: Ahmad Zidan Date: Mon, 28 Apr 2025 15:05:20 +0700 Subject: [PATCH] feat: add AWS Managed IAM Auth support for OpenSearch vector DB --- .../middleware/vdb/opensearch_config.py | 18 +++++- .../vdb/opensearch/opensearch_vector.py | 51 +++++++++++++---- .../vdb/opensearch/test_opensearch.py | 57 +++++++++++++++++++ docker/.env.example | 4 ++ docker/docker-compose.yaml | 3 + 5 files changed, 121 insertions(+), 12 deletions(-) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 81dde4c04d..d802aab07c 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -33,3 +33,19 @@ class OpenSearchConfig(BaseSettings): description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", default=False, ) + + OPENSEARCH_USE_AWS_MANAGED_IAM: bool = Field( + description="Whether to use AWS IAM authentication for OpenSearch clusters " + "running in Amazon Managed OpenSearch or OpenSearch Serverless", + default=False, + ) + + OPENSEARCH_AWS_REGION: Optional[str] = Field( + description="AWS region for OpenSearch (e.g. 'us-west-2')", + default=None, + ) + + OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] = Field( + description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", + default="aoss" + ) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 6636646cff..4f9e7b1fb5 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,10 +1,9 @@ import json import logging -import ssl from typing import Any, Optional from uuid import uuid4 -from opensearchpy import OpenSearch, helpers +from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator @@ -26,6 +25,9 @@ class OpenSearchConfig(BaseModel): port: int user: Optional[str] = None password: Optional[str] = None + use_aws_managed_iam: bool = False + aws_region: Optional[str] = None + aws_service: Optional[str] = None secure: bool = False @model_validator(mode="before") @@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel): raise ValueError("config OPENSEARCH_HOST is required") if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") + if values.get("use_aws_managed_iam"): + if not values.get("aws_region"): + raise ValueError("config OPENSEARCH_AWS_REGION is required") + if not values.get("aws_service"): + raise ValueError("config OPENSEARCH_AWS_SERVICE is required") return values - def create_ssl_context(self) -> ssl.SSLContext: - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation - return ssl_context + def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth: + import boto3 # type: ignore + + return Urllib3AWSV4SignerAuth( + credentials=boto3.Session().get_credentials(), + region=self.aws_region, + service=self.aws_service, # type: ignore[arg-type] + ) def to_opensearch_params(self) -> dict[str, Any]: params = { "hosts": [{"host": self.host, "port": self.port}], "use_ssl": self.secure, "verify_certs": self.secure, + "connection_class": Urllib3HttpConnection, + "pool_maxsize": 20, } + if self.user and self.password: + logger.info("Using basic authentication for OpenSearch Vector DB") + params["http_auth"] = (self.user, self.password) - if self.secure: - params["ssl_context"] = self.create_ssl_context() + elif self.use_aws_managed_iam: + logger.info("Using AWS managed IAM role for OpenSearch Vector DB") + + params["http_auth"] = self.create_aws_managed_iam_auth() + return params @@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector): action = { "_op_type": "index", "_index": self._collection_name.lower(), - "_id": uuid4().hex, "_source": { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, }, } + # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377 + if self._client_config.aws_service not in ["aoss"]: + action["_id"] = uuid4().hex actions.append(action) - helpers.bulk(self._client, actions) + helpers.bulk( + client=self._client, + actions=actions, + timeout=30, + max_retries=3, + ) def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} @@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector): }, } + logger.info(f"Creating OpenSearch index {self._collection_name.lower()}") self._client.indices.create(index=self._collection_name.lower(), body=index_body) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -254,6 +280,9 @@ class OpenSearchVectorFactory(AbstractVectorFactory): port=dify_config.OPENSEARCH_PORT, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, + use_aws_managed_iam=dify_config.OPENSEARCH_USE_AWS_MANAGED_IAM, + aws_region=dify_config.OPENSEARCH_AWS_REGION, + aws_service=dify_config.OPENSEARCH_AWS_SERVICE, secure=dify_config.OPENSEARCH_SECURE, ) diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index 35eed75c2f..ff68d3f8c9 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -23,6 +23,63 @@ def setup_mock_redis(): ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) +class TestOpenSearchConfig: + def test_to_opensearch_params(self): + config = OpenSearchConfig( + host="localhost", + port=9200, + user="admin", + password="password", + secure=True, + ) + + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": "localhost", "port": 9200}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"] == ("admin", "password") + + @patch("boto3.Session") + @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth") + def test_to_opensearch_params_with_aws_managed_iam( + self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock + ): + mock_credentials = MagicMock() + mock_boto_session.return_value.get_credentials.return_value = mock_credentials + + mock_auth_instance = MagicMock() + mock_aws_signer_auth.return_value = mock_auth_instance + + aws_region = "ap-southeast-2" + aws_service = "aoss" + host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com" + port = 9201 + + config = OpenSearchConfig( + host=host, + port=port, + use_aws_managed_iam=True, + aws_region=aws_region, + aws_service=aws_service, + secure=True, + ) + + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": host, "port": port}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"] is mock_auth_instance + + mock_aws_signer_auth.assert_called_once_with( + credentials=mock_credentials, region=aws_region, service=aws_service + ) + assert mock_boto_session.return_value.get_credentials.called + + class TestOpenSearchVector: def setup_method(self): self.collection_name = "test_collection" diff --git a/docker/.env.example b/docker/.env.example index 83d975cec5..31a54806c0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -523,6 +523,10 @@ OPENSEARCH_PORT=9200 OPENSEARCH_USER=admin OPENSEARCH_PASSWORD=admin OPENSEARCH_SECURE=true +# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless, set to true. +OPENSEARCH_USE_AWS_MANAGED_IAM=false +OPENSEARCH_AWS_REGION=ap-southeast-1 +OPENSEARCH_AWS_SERVICE=aoss # tencent vector configurations, only available when VECTOR_STORE is `tencent` TENCENT_VECTOR_DB_URL=http://127.0.0.1 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1bf6954299..e0e713b164 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -228,6 +228,9 @@ x-shared-env: &shared-api-worker-env OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} + OPENSEARCH_USE_AWS_MANAGED_IAM: ${OPENSEARCH_USE_AWS_MANAGED_IAM:-false} + OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1} + OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss} TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}