|
|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
from typing import Any, Literal, Optional
|
|
|
|
|
from uuid import uuid4
|
|
|
|
|
|
|
|
|
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
|
|
|
|
@ -23,12 +23,12 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
class OpenSearchConfig(BaseModel):
|
|
|
|
|
host: str
|
|
|
|
|
port: int
|
|
|
|
|
secure: bool = False
|
|
|
|
|
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
|
|
|
|
|
user: Optional[str] = None
|
|
|
|
|
password: Optional[str] = None
|
|
|
|
|
use_aws_managed_iam: bool = False
|
|
|
|
|
aws_region: Optional[str] = None
|
|
|
|
|
aws_service: Optional[str] = None
|
|
|
|
|
secure: bool = False
|
|
|
|
|
|
|
|
|
|
@model_validator(mode="before")
|
|
|
|
|
@classmethod
|
|
|
|
|
@ -37,11 +37,11 @@ class OpenSearchConfig(BaseModel):
|
|
|
|
|
raise ValueError("config OPENSEARCH_HOST is required")
|
|
|
|
|
if not values.get("port"):
|
|
|
|
|
raise ValueError("config OPENSEARCH_PORT is required")
|
|
|
|
|
if values.get("use_aws_managed_iam"):
|
|
|
|
|
if values.get("auth_method") == "aws_managed_iam":
|
|
|
|
|
if not values.get("aws_region"):
|
|
|
|
|
raise ValueError("config OPENSEARCH_AWS_REGION is required")
|
|
|
|
|
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
|
|
|
|
|
if not values.get("aws_service"):
|
|
|
|
|
raise ValueError("config OPENSEARCH_AWS_SERVICE is required")
|
|
|
|
|
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
|
|
|
|
|
@ -62,11 +62,11 @@ class OpenSearchConfig(BaseModel):
|
|
|
|
|
"pool_maxsize": 20,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if self.user and self.password:
|
|
|
|
|
if self.auth_method == "basic":
|
|
|
|
|
logger.info("Using basic authentication for OpenSearch Vector DB")
|
|
|
|
|
|
|
|
|
|
params["http_auth"] = (self.user, self.password)
|
|
|
|
|
elif self.use_aws_managed_iam:
|
|
|
|
|
elif self.auth_method == "aws_managed_iam":
|
|
|
|
|
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
|
|
|
|
|
|
|
|
|
|
params["http_auth"] = self.create_aws_managed_iam_auth()
|
|
|
|
|
@ -278,12 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
|
|
|
|
open_search_config = OpenSearchConfig(
|
|
|
|
|
host=dify_config.OPENSEARCH_HOST or "localhost",
|
|
|
|
|
port=dify_config.OPENSEARCH_PORT,
|
|
|
|
|
secure=dify_config.OPENSEARCH_SECURE,
|
|
|
|
|
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
|
|
|
|
|
user=dify_config.OPENSEARCH_USER,
|
|
|
|
|
password=dify_config.OPENSEARCH_PASSWORD,
|
|
|
|
|
use_aws_managed_iam=dify_config.OPENSEARCH_USE_AWS_MANAGED_IAM,
|
|
|
|
|
aws_region=dify_config.OPENSEARCH_AWS_REGION,
|
|
|
|
|
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
|
|
|
|
|
secure=dify_config.OPENSEARCH_SECURE,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
|
|
|
|
|
|