|
|
|
|
@ -1,12 +1,13 @@
|
|
|
|
|
import base64
|
|
|
|
|
import hashlib
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
from collections.abc import Mapping
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
from Crypto.Cipher import AES
|
|
|
|
|
from Crypto.Random import get_random_bytes
|
|
|
|
|
from Crypto.Util.Padding import pad, unpad
|
|
|
|
|
from pydantic import TypeAdapter
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
|
|
|
|
|
@ -42,7 +43,7 @@ class SystemOAuthEncrypter:
|
|
|
|
|
# Generate a fixed 256-bit key using SHA-256
|
|
|
|
|
self.key = hashlib.sha256(secret_key.encode()).digest()
|
|
|
|
|
|
|
|
|
|
def encrypt_oauth_params(self, oauth_params: str) -> str:
|
|
|
|
|
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Encrypt OAuth parameters.
|
|
|
|
|
|
|
|
|
|
@ -57,9 +58,6 @@ class SystemOAuthEncrypter:
|
|
|
|
|
ValueError: If oauth_params is invalid
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if not oauth_params:
|
|
|
|
|
raise ValueError("oauth_params cannot be empty")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# Generate random IV (16 bytes)
|
|
|
|
|
iv = get_random_bytes(16)
|
|
|
|
|
@ -68,7 +66,7 @@ class SystemOAuthEncrypter:
|
|
|
|
|
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
|
|
|
|
|
|
|
|
|
# Encrypt data
|
|
|
|
|
padded_data = pad(oauth_params.encode("utf-8"), AES.block_size)
|
|
|
|
|
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
|
|
|
|
encrypted_data = cipher.encrypt(padded_data)
|
|
|
|
|
|
|
|
|
|
# Combine IV and encrypted data
|
|
|
|
|
@ -80,7 +78,7 @@ class SystemOAuthEncrypter:
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
def decrypt_oauth_params(self, encrypted_data: str) -> dict[str, Any]:
|
|
|
|
|
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
Decrypt OAuth parameters.
|
|
|
|
|
|
|
|
|
|
@ -120,16 +118,13 @@ class SystemOAuthEncrypter:
|
|
|
|
|
unpadded_data = unpad(decrypted_data, AES.block_size)
|
|
|
|
|
|
|
|
|
|
# Parse JSON
|
|
|
|
|
params_json = unpadded_data.decode("utf-8")
|
|
|
|
|
oauth_params = json.loads(params_json)
|
|
|
|
|
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
|
|
|
|
|
|
|
|
|
if not isinstance(oauth_params, dict):
|
|
|
|
|
raise ValueError("Decrypted data is not a valid dictionary")
|
|
|
|
|
|
|
|
|
|
return oauth_params
|
|
|
|
|
|
|
|
|
|
except (ValueError, TypeError) as e:
|
|
|
|
|
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
@ -166,7 +161,7 @@ def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Convenience functions for backward compatibility
|
|
|
|
|
def encrypt_system_oauth_params(oauth_params: str) -> str:
|
|
|
|
|
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Encrypt OAuth parameters using the global encrypter.
|
|
|
|
|
|
|
|
|
|
@ -179,7 +174,7 @@ def encrypt_system_oauth_params(oauth_params: str) -> str:
|
|
|
|
|
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decrypt_system_oauth_params(encrypted_data: str) -> dict[str, Any]:
|
|
|
|
|
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
Decrypt OAuth parameters using the global encrypter.
|
|
|
|
|
|
|
|
|
|
|