feat: add AWS Managed IAM Auth support for OpenSearch vector DB

pull/18963/head
Ahmad Zidan 1 year ago
parent a54773fbff
commit ad9e06c962
No known key found for this signature in database
GPG Key ID: 2CEE982320CE9D20

@ -1,4 +1,4 @@
from typing import Optional from typing import Literal, Optional
from pydantic import Field, PositiveInt from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -33,3 +33,19 @@ class OpenSearchConfig(BaseSettings):
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False, default=False,
) )
# Opt-in switch for AWS IAM authentication; disabled unless explicitly set.
OPENSEARCH_USE_AWS_MANAGED_IAM: bool = Field(
    default=False,
    description=(
        "Whether to use AWS IAM authentication for OpenSearch clusters "
        "running in Amazon Managed OpenSearch or OpenSearch Serverless"
    ),
)

# Region is optional at the settings level and defaults to None.
OPENSEARCH_AWS_REGION: Optional[str] = Field(
    default=None,
    description="AWS region for OpenSearch (e.g. 'us-west-2')",
)

# The Literal annotation limits accepted values to "es" or "aoss".
OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] = Field(
    default="aoss",
    description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)",
)

@ -1,10 +1,9 @@
import json import json
import logging import logging
import ssl
from typing import Any, Optional from typing import Any, Optional
from uuid import uuid4 from uuid import uuid4
from opensearchpy import OpenSearch, helpers from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -26,6 +25,9 @@ class OpenSearchConfig(BaseModel):
port: int port: int
user: Optional[str] = None user: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
use_aws_managed_iam: bool = False
aws_region: Optional[str] = None
aws_service: Optional[str] = None
secure: bool = False secure: bool = False
@model_validator(mode="before") @model_validator(mode="before")
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required") raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"): if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required") raise ValueError("config OPENSEARCH_PORT is required")
if values.get("use_aws_managed_iam"):
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required")
return values return values
def create_ssl_context(self) -> ssl.SSLContext: def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
ssl_context = ssl.create_default_context() import boto3 # type: ignore
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation return Urllib3AWSV4SignerAuth(
return ssl_context credentials=boto3.Session().get_credentials(),
region=self.aws_region,
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]: def to_opensearch_params(self) -> dict[str, Any]:
params = { params = {
"hosts": [{"host": self.host, "port": self.port}], "hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure, "use_ssl": self.secure,
"verify_certs": self.secure, "verify_certs": self.secure,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
} }
if self.user and self.password: if self.user and self.password:
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password) params["http_auth"] = (self.user, self.password)
if self.secure: elif self.use_aws_managed_iam:
params["ssl_context"] = self.create_ssl_context() logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
return params return params
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = { action = {
"_op_type": "index", "_op_type": "index",
"_index": self._collection_name.lower(), "_index": self._collection_name.lower(),
"_id": uuid4().hex,
"_source": { "_source": {
Field.CONTENT_KEY.value: documents[i].page_content, Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata, Field.METADATA_KEY.value: documents[i].metadata,
}, },
} }
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
action["_id"] = uuid4().hex
actions.append(action) actions.append(action)
helpers.bulk(self._client, actions) helpers.bulk(
client=self._client,
actions=actions,
timeout=30,
max_retries=3,
)
def get_ids_by_metadata_field(self, key: str, value: str): def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
}, },
} }
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
self._client.indices.create(index=self._collection_name.lower(), body=index_body) self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -254,6 +280,9 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
port=dify_config.OPENSEARCH_PORT, port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER, user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD, password=dify_config.OPENSEARCH_PASSWORD,
use_aws_managed_iam=dify_config.OPENSEARCH_USE_AWS_MANAGED_IAM,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
secure=dify_config.OPENSEARCH_SECURE, secure=dify_config.OPENSEARCH_SECURE,
) )

@ -23,6 +23,63 @@ def setup_mock_redis():
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
class TestOpenSearchConfig:
    """Checks the client parameters built by OpenSearchConfig.to_opensearch_params()."""

    def test_to_opensearch_params(self):
        """With user and password supplied, http_auth is the (user, password) tuple."""
        opensearch_config = OpenSearchConfig(
            host="localhost",
            port=9200,
            user="admin",
            password="password",
            secure=True,
        )

        built = opensearch_config.to_opensearch_params()

        assert built["hosts"] == [{"host": "localhost", "port": 9200}]
        assert built["use_ssl"] is True
        assert built["verify_certs"] is True
        assert built["connection_class"].__name__ == "Urllib3HttpConnection"
        assert built["http_auth"] == ("admin", "password")

    # Decorators apply bottom-up, so the signer mock is the first injected
    # argument and the boto3 session mock is the second.
    @patch("boto3.Session")
    @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
    def test_to_opensearch_params_with_aws_managed_iam(
        self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
    ):
        """With IAM enabled, http_auth is the signer built from the session credentials."""
        # Stand-ins for the credentials object and the signer instance.
        fake_credentials = MagicMock()
        fake_signer = MagicMock()
        get_credentials = mock_boto_session.return_value.get_credentials
        get_credentials.return_value = fake_credentials
        mock_aws_signer_auth.return_value = fake_signer

        aws_region = "ap-southeast-2"
        aws_service = "aoss"
        host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
        port = 9201

        opensearch_config = OpenSearchConfig(
            host=host,
            port=port,
            use_aws_managed_iam=True,
            aws_region=aws_region,
            aws_service=aws_service,
            secure=True,
        )
        built = opensearch_config.to_opensearch_params()

        assert built["hosts"] == [{"host": host, "port": port}]
        assert built["use_ssl"] is True
        assert built["verify_certs"] is True
        assert built["connection_class"].__name__ == "Urllib3HttpConnection"
        assert built["http_auth"] is fake_signer

        # The signer must have been created exactly once from the mocked credentials.
        mock_aws_signer_auth.assert_called_once_with(
            credentials=fake_credentials, region=aws_region, service=aws_service
        )
        assert get_credentials.called
class TestOpenSearchVector: class TestOpenSearchVector:
def setup_method(self): def setup_method(self):
self.collection_name = "test_collection" self.collection_name = "test_collection"

@ -523,6 +523,10 @@ OPENSEARCH_PORT=9200
OPENSEARCH_USER=admin OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true OPENSEARCH_SECURE=true
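# Set to true to authenticate with AWS IAM (e.g. an Amazon-managed OpenSearch cluster or OpenSearch Serverless).
# When enabled, OPENSEARCH_AWS_REGION and OPENSEARCH_AWS_SERVICE are required.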
OPENSEARCH_USE_AWS_MANAGED_IAM=false
OPENSEARCH_AWS_REGION=ap-southeast-1
OPENSEARCH_AWS_SERVICE=aoss
# tencent vector configurations, only available when VECTOR_STORE is `tencent` # tencent vector configurations, only available when VECTOR_STORE is `tencent`
TENCENT_VECTOR_DB_URL=http://127.0.0.1 TENCENT_VECTOR_DB_URL=http://127.0.0.1

@ -228,6 +228,9 @@ x-shared-env: &shared-api-worker-env
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
OPENSEARCH_USE_AWS_MANAGED_IAM: ${OPENSEARCH_USE_AWS_MANAGED_IAM:-false}
OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1}
OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss}
TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}

Loading…
Cancel
Save