|
|
|
|
@ -23,7 +23,8 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
class OpenSearchConfig(BaseModel):
|
|
|
|
|
host: str
|
|
|
|
|
port: int
|
|
|
|
|
secure: bool = False
|
|
|
|
|
secure: bool = False # use_ssl
|
|
|
|
|
verify_certs: bool = True
|
|
|
|
|
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
|
|
|
|
|
user: Optional[str] = None
|
|
|
|
|
password: Optional[str] = None
|
|
|
|
|
@ -42,6 +43,8 @@ class OpenSearchConfig(BaseModel):
|
|
|
|
|
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
|
|
|
|
|
if not values.get("aws_service"):
|
|
|
|
|
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
|
|
|
|
|
if not values.get("OPENSEARCH_SECURE") and values.get("OPENSEARCH_VERIFY_CERTS"):
|
|
|
|
|
raise ValueError("verify_certs=True requires secure (HTTPS) connection")
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
|
|
|
|
|
@ -57,7 +60,7 @@ class OpenSearchConfig(BaseModel):
|
|
|
|
|
params = {
|
|
|
|
|
"hosts": [{"host": self.host, "port": self.port}],
|
|
|
|
|
"use_ssl": self.secure,
|
|
|
|
|
"verify_certs": self.secure,
|
|
|
|
|
"verify_certs": self.verify_certs,
|
|
|
|
|
"connection_class": Urllib3HttpConnection,
|
|
|
|
|
"pool_maxsize": 20,
|
|
|
|
|
}
|
|
|
|
|
@ -279,6 +282,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
|
|
|
|
host=dify_config.OPENSEARCH_HOST or "localhost",
|
|
|
|
|
port=dify_config.OPENSEARCH_PORT,
|
|
|
|
|
secure=dify_config.OPENSEARCH_SECURE,
|
|
|
|
|
verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
|
|
|
|
|
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
|
|
|
|
|
user=dify_config.OPENSEARCH_USER,
|
|
|
|
|
password=dify_config.OPENSEARCH_PASSWORD,
|
|
|
|
|
|