refactor: tidy up

pull/18963/head
Ahmad Zidan 1 year ago
parent ad9e06c962
commit 2e43b61406
No known key found for this signature in database
GPG Key ID: 2CEE982320CE9D20

@ -1,3 +1,4 @@
import enum
from typing import Literal, Optional from typing import Literal, Optional
from pydantic import Field, PositiveInt from pydantic import Field, PositiveInt
@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch Configuration settings for OpenSearch
""" """
class AuthMethod(enum.StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: Optional[str] = Field( OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None, default=None,
@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
default=9200, default=9200,
) )
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
description="Authentication method for OpenSearch connection (default is 'basic')",
default=AuthMethod.BASIC,
)
OPENSEARCH_USER: Optional[str] = Field( OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch", description="Username for authenticating with OpenSearch",
default=None, default=None,
@ -29,23 +48,12 @@ class OpenSearchConfig(BaseSettings):
default=None, default=None,
) )
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_USE_AWS_MANAGED_IAM: bool = Field(
description="Whether to use AWS IAM authentication for OpenSearch clusters "
"running in Amazon Managed OpenSearch or OpenSearch Serverless",
default=False,
)
OPENSEARCH_AWS_REGION: Optional[str] = Field( OPENSEARCH_AWS_REGION: Optional[str] = Field(
description="AWS region for OpenSearch (e.g. 'us-west-2')", description="AWS region for OpenSearch (e.g. 'us-west-2')",
default=None, default=None,
) )
OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] = Field( OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)",
default="aoss" default=None
) )

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Any, Optional from typing import Any, Literal, Optional
from uuid import uuid4 from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
@ -23,12 +23,12 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel): class OpenSearchConfig(BaseModel):
host: str host: str
port: int port: int
secure: bool = False
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None user: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
use_aws_managed_iam: bool = False
aws_region: Optional[str] = None aws_region: Optional[str] = None
aws_service: Optional[str] = None aws_service: Optional[str] = None
secure: bool = False
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -37,11 +37,11 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required") raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"): if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required") raise ValueError("config OPENSEARCH_PORT is required")
if values.get("use_aws_managed_iam"): if values.get("auth_method") == "aws_managed_iam":
if not values.get("aws_region"): if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required") raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"): if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required") raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
return values return values
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth: def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
@ -62,11 +62,11 @@ class OpenSearchConfig(BaseModel):
"pool_maxsize": 20, "pool_maxsize": 20,
} }
if self.user and self.password: if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB") logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password) params["http_auth"] = (self.user, self.password)
elif self.use_aws_managed_iam: elif self.auth_method == "aws_managed_iam":
logger.info("Using AWS managed IAM role for OpenSearch Vector DB") logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth() params["http_auth"] = self.create_aws_managed_iam_auth()
@ -278,12 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig( open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost", host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT, port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER, user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD, password=dify_config.OPENSEARCH_PASSWORD,
use_aws_managed_iam=dify_config.OPENSEARCH_USE_AWS_MANAGED_IAM,
aws_region=dify_config.OPENSEARCH_AWS_REGION, aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE, aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
secure=dify_config.OPENSEARCH_SECURE,
) )
return OpenSearchVector(collection_name=collection_name, config=open_search_config) return OpenSearchVector(collection_name=collection_name, config=open_search_config)

@ -28,9 +28,9 @@ class TestOpenSearchConfig:
config = OpenSearchConfig( config = OpenSearchConfig(
host="localhost", host="localhost",
port=9200, port=9200,
secure=True,
user="admin", user="admin",
password="password", password="password",
secure=True,
) )
params = config.to_opensearch_params() params = config.to_opensearch_params()
@ -60,10 +60,10 @@ class TestOpenSearchConfig:
config = OpenSearchConfig( config = OpenSearchConfig(
host=host, host=host,
port=port, port=port,
use_aws_managed_iam=True, secure=True,
auth_method="aws_managed_iam",
aws_region=aws_region, aws_region=aws_region,
aws_service=aws_service, aws_service=aws_service,
secure=True,
) )
params = config.to_opensearch_params() params = config.to_opensearch_params()
@ -86,7 +86,7 @@ class TestOpenSearchVector:
self.example_doc_id = "example_doc_id" self.example_doc_id = "example_doc_id"
self.vector = OpenSearchVector( self.vector = OpenSearchVector(
collection_name=self.collection_name, collection_name=self.collection_name,
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
) )
self.vector._client = MagicMock() self.vector._client = MagicMock()

Loading…
Cancel
Save