diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py new file mode 100644 index 0000000000..f32b950619 --- /dev/null +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -0,0 +1,192 @@ +import base64 +import hashlib +import json +import logging +from typing import Any, Optional + +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad, unpad + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class OAuthEncryptionError(Exception): + """OAuth encryption/decryption specific error""" + + pass + + +class SystemOAuthEncrypter: + """ + A simple OAuth parameters encrypter using AES-CBC encryption. + + This class provides methods to encrypt and decrypt OAuth parameters + using AES-CBC mode with a key derived from the application's SECRET_KEY. + """ + + def __init__(self, secret_key: Optional[str] = None): + """ + Initialize the OAuth encrypter. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Raises: + ValueError: If SECRET_KEY is not configured or empty + """ + secret_key = secret_key or dify_config.SECRET_KEY or "" + + # Generate a fixed 256-bit key using SHA-256 + self.key = hashlib.sha256(secret_key.encode()).digest() + + def encrypt_oauth_params(self, oauth_params: str) -> str: + """ + Encrypt OAuth parameters. + + Args: + oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + + Returns: + Base64-encoded encrypted string + + Raises: + OAuthEncryptionError: If encryption fails + ValueError: If oauth_params is invalid + """ + + if not oauth_params: + raise ValueError("oauth_params cannot be empty") + + try: + # Generate random IV (16 bytes) + iv = get_random_bytes(16) + + # Create AES cipher (CBC mode) + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Encrypt data + padded_data = pad(oauth_params.encode("utf-8"), AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + + # Combine IV and encrypted data + combined = iv + encrypted_data + + # Return base64 encoded string + return base64.b64encode(combined).decode() + + except Exception as e: + raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + + def decrypt_oauth_params(self, encrypted_data: str) -> dict[str, Any]: + """ + Decrypt OAuth parameters. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + + Raises: + OAuthEncryptionError: If decryption fails + ValueError: If encrypted_data is invalid + """ + if not isinstance(encrypted_data, str): + raise ValueError("encrypted_data must be a string") + + if not encrypted_data: + raise ValueError("encrypted_data cannot be empty") + + try: + # Base64 decode + combined = base64.b64decode(encrypted_data) + + # Check minimum length (IV + at least one AES block) + if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data + raise ValueError("Invalid encrypted data format") + + # Separate IV and encrypted data + iv = combined[:16] + encrypted_data_bytes = combined[16:] + + # Create AES cipher + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Decrypt data + decrypted_data = cipher.decrypt(encrypted_data_bytes) + unpadded_data = unpad(decrypted_data, AES.block_size) + + # Parse JSON + params_json = unpadded_data.decode("utf-8") + oauth_params = json.loads(params_json) + + if not isinstance(oauth_params, dict): + raise ValueError("Decrypted data is not a valid dictionary") + + return oauth_params + + except (ValueError, TypeError) as e: + raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + except Exception as e: + raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + + +# Factory function for creating encrypter instances +def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter: + """ + Create an OAuth encrypter instance. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Returns: + SystemOAuthEncrypter instance + """ + return SystemOAuthEncrypter(secret_key=secret_key) + + +# Global encrypter instance (for backward compatibility) +_oauth_encrypter: Optional[SystemOAuthEncrypter] = None + + +def get_system_oauth_encrypter() -> SystemOAuthEncrypter: + """ + Get the global OAuth encrypter instance. + + Returns: + SystemOAuthEncrypter instance + """ + global _oauth_encrypter + if _oauth_encrypter is None: + _oauth_encrypter = SystemOAuthEncrypter() + return _oauth_encrypter + + +# Convenience functions for backward compatibility +def encrypt_system_oauth_params(oauth_params: str) -> str: + """ + Encrypt OAuth parameters using the global encrypter. + + Args: + oauth_params: OAuth parameters dictionary + + Returns: + Base64-encoded encrypted string + """ + return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + + +def decrypt_system_oauth_params(encrypted_data: str) -> dict[str, Any]: + """ + Decrypt OAuth parameters using the global encrypter. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + """ + return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) diff --git a/api/models/tools.py b/api/models/tools.py index 34bc97d006..a1b9c0710e 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -35,10 +35,6 @@ class ToolOAuthSystemClient(Base): # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) - @property - def oauth_params(self) -> dict: - return cast(dict, json.loads(self.encrypted_oauth_params)) - # tenant level tool oauth client params (client_id, client_secret, etc.) class ToolOAuthTenantClient(Base): diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 66157fb6b6..4229e3bb9c 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -26,6 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_provider_encrypter +from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient @@ -492,7 +493,7 @@ class BuiltinToolManageService: .first() ) if system_client: - oauth_params = encrypter.decrypt(system_client.oauth_params) + oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) return oauth_params