refactor: tidy up

pull/18963/head
Ahmad Zidan 1 year ago
parent ad9e06c962
commit 2e43b61406
No known key found for this signature in database
GPG Key ID: 2CEE982320CE9D20

@ -1,3 +1,4 @@
import enum
from typing import Literal, Optional
from pydantic import Field, PositiveInt
@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch
"""
class AuthMethod(enum.StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,
@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
default=9200,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
description="Authentication method for OpenSearch connection (default is 'basic')",
default=AuthMethod.BASIC,
)
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
default=None,
@ -29,23 +48,12 @@ class OpenSearchConfig(BaseSettings):
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_USE_AWS_MANAGED_IAM: bool = Field(
description="Whether to use AWS IAM authentication for OpenSearch clusters "
"running in Amazon Managed OpenSearch or OpenSearch Serverless",
default=False,
)
OPENSEARCH_AWS_REGION: Optional[str] = Field(
description="AWS region for OpenSearch (e.g. 'us-west-2')",
default=None,
)
OPENSEARCH_AWS_SERVICE: Literal["es", "aoss"] = Field(
OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)",
default="aoss"
default=None
)

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional
from typing import Any, Literal, Optional
from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
@ -23,12 +23,12 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
secure: bool = False
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None
password: Optional[str] = None
use_aws_managed_iam: bool = False
aws_region: Optional[str] = None
aws_service: Optional[str] = None
secure: bool = False
@model_validator(mode="before")
@classmethod
@ -37,11 +37,11 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
if values.get("use_aws_managed_iam"):
if values.get("auth_method") == "aws_managed_iam":
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required")
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required")
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
return values
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
@ -62,11 +62,11 @@ class OpenSearchConfig(BaseModel):
"pool_maxsize": 20,
}
if self.user and self.password:
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
elif self.use_aws_managed_iam:
elif self.auth_method == "aws_managed_iam":
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
@ -278,12 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
use_aws_managed_iam=dify_config.OPENSEARCH_USE_AWS_MANAGED_IAM,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
secure=dify_config.OPENSEARCH_SECURE,
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)

@ -28,9 +28,9 @@ class TestOpenSearchConfig:
config = OpenSearchConfig(
host="localhost",
port=9200,
secure=True,
user="admin",
password="password",
secure=True,
)
params = config.to_opensearch_params()
@ -60,10 +60,10 @@ class TestOpenSearchConfig:
config = OpenSearchConfig(
host=host,
port=port,
use_aws_managed_iam=True,
secure=True,
auth_method="aws_managed_iam",
aws_region=aws_region,
aws_service=aws_service,
secure=True,
)
params = config.to_opensearch_params()
@ -86,7 +86,7 @@ class TestOpenSearchVector:
self.example_doc_id = "example_doc_id"
self.vector = OpenSearchVector(
collection_name=self.collection_name,
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
)
self.vector._client = MagicMock()

Loading…
Cancel
Save