diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index d802aab07c..55817eca3e 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,3 +1,4 @@ +import enum from typing import Literal, Optional from pydantic import Field, PositiveInt @@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings): Configuration settings for OpenSearch """ + class AuthMethod(enum.StrEnum): + """ + Authentication method for OpenSearch + """ + + BASIC = "basic" + AWS_MANAGED_IAM = "aws_managed_iam" + OPENSEARCH_HOST: Optional[str] = Field( description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, @@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings): default=9200, ) + OPENSEARCH_SECURE: bool = Field( + description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", + default=False, + ) + + OPENSEARCH_AUTH_METHOD: AuthMethod = Field( + description="Authentication method for OpenSearch connection (default is 'basic')", + default=AuthMethod.BASIC, + ) + OPENSEARCH_USER: Optional[str] = Field( description="Username for authenticating with OpenSearch", default=None, @@ -29,23 +48,12 @@ class OpenSearchConfig(BaseSettings): default=None, ) - OPENSEARCH_SECURE: bool = Field( - description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", - default=False, - ) - - OPENSEARCH_USE_AWS_MANAGED_IAM: bool = Field( - description="Whether to use AWS IAM authentication for OpenSearch clusters " - "running in Amazon Managed OpenSearch or OpenSearch Serverless", - default=False, - ) - OPENSEARCH_AWS_REGION: Optional[str] = Field( description="AWS region for OpenSearch (e.g. 'us-west-2')", default=None, ) - OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] = Field( + OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field( description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", - default="aoss" + default=None ) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 4f9e7b1fb5..e23b8d197f 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any, Literal, Optional from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers @@ -23,12 +23,12 @@ logger = logging.getLogger(__name__) class OpenSearchConfig(BaseModel): host: str port: int + secure: bool = False + auth_method: Literal["basic", "aws_managed_iam"] = "basic" user: Optional[str] = None password: Optional[str] = None - use_aws_managed_iam: bool = False aws_region: Optional[str] = None aws_service: Optional[str] = None - secure: bool = False @model_validator(mode="before") @classmethod @@ -37,11 +37,11 @@ class OpenSearchConfig(BaseModel): raise ValueError("config OPENSEARCH_HOST is required") if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") - if values.get("use_aws_managed_iam"): + if values.get("auth_method") == "aws_managed_iam": if not values.get("aws_region"): - raise ValueError("config OPENSEARCH_AWS_REGION is required") + raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method") if not values.get("aws_service"): - raise ValueError("config OPENSEARCH_AWS_SERVICE is required") + raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method") return values def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth: @@ -62,11 +62,11 @@ class OpenSearchConfig(BaseModel): "pool_maxsize": 20, } - if self.user and self.password: + if self.auth_method == "basic": logger.info("Using basic authentication for OpenSearch Vector DB") params["http_auth"] = (self.user, self.password) - elif self.use_aws_managed_iam: + elif self.auth_method == "aws_managed_iam": logger.info("Using AWS managed IAM role for OpenSearch Vector DB") params["http_auth"] = self.create_aws_managed_iam_auth() @@ -278,12 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory): open_search_config = OpenSearchConfig( host=dify_config.OPENSEARCH_HOST or "localhost", port=dify_config.OPENSEARCH_PORT, + secure=dify_config.OPENSEARCH_SECURE, + auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, - use_aws_managed_iam=dify_config.OPENSEARCH_USE_AWS_MANAGED_IAM, aws_region=dify_config.OPENSEARCH_AWS_REGION, aws_service=dify_config.OPENSEARCH_AWS_SERVICE, - secure=dify_config.OPENSEARCH_SECURE, ) return OpenSearchVector(collection_name=collection_name, config=open_search_config) diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index ff68d3f8c9..2d44dd2924 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -28,9 +28,9 @@ class TestOpenSearchConfig: config = OpenSearchConfig( host="localhost", port=9200, + secure=True, user="admin", password="password", - secure=True, ) params = config.to_opensearch_params() @@ -60,10 +60,10 @@ class TestOpenSearchConfig: config = OpenSearchConfig( host=host, port=port, - use_aws_managed_iam=True, + secure=True, + auth_method="aws_managed_iam", aws_region=aws_region, aws_service=aws_service, - secure=True, ) params = config.to_opensearch_params() @@ -86,7 +86,7 @@ class TestOpenSearchVector: self.example_doc_id = "example_doc_id" self.vector = OpenSearchVector( collection_name=self.collection_name, - config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), + config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"), ) self.vector._client = MagicMock()