diff --git a/api/commands.py b/api/commands.py index 125ec1b770..9f933a378c 100644 --- a/api/commands.py +++ b/api/commands.py @@ -2,10 +2,11 @@ import base64 import json import logging import secrets -from typing import Optional +from typing import Any, Optional import click from flask import current_app +from pydantic import TypeAdapter from sqlalchemy import select from werkzeug.exceptions import NotFound @@ -1174,12 +1175,12 @@ def setup_system_tool_oauth_client(provider, client_params): try: # json validate click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) - json.loads(client_params) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) click.echo(click.style("Client params validated successfully.", fg="green")) click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 1068b07062..ddec7b1329 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -17,7 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) - credential_type: Optional[CredentialType] = CredentialType.API_KEY + credential_type: CredentialType = Field(default=CredentialType.API_KEY) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py index f32b950619..f3c946b95f 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -1,12 +1,13 @@ import base64 import hashlib -import json import logging +from collections.abc import Mapping from typing import Any, Optional from Crypto.Cipher import AES from Crypto.Random import get_random_bytes from Crypto.Util.Padding import pad, unpad +from pydantic import TypeAdapter from configs import dify_config @@ -42,7 +43,7 @@ class SystemOAuthEncrypter: # Generate a fixed 256-bit key using SHA-256 self.key = hashlib.sha256(secret_key.encode()).digest() - def encrypt_oauth_params(self, oauth_params: str) -> str: + def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: """ Encrypt OAuth parameters. @@ -57,9 +58,6 @@ class SystemOAuthEncrypter: ValueError: If oauth_params is invalid """ - if not oauth_params: - raise ValueError("oauth_params cannot be empty") - try: # Generate random IV (16 bytes) iv = get_random_bytes(16) @@ -68,7 +66,7 @@ class SystemOAuthEncrypter: cipher = AES.new(self.key, AES.MODE_CBC, iv) # Encrypt data - padded_data = pad(oauth_params.encode("utf-8"), AES.block_size) + padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) encrypted_data = cipher.encrypt(padded_data) # Combine IV and encrypted data @@ -80,7 +78,7 @@ class SystemOAuthEncrypter: except Exception as e: raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e - def decrypt_oauth_params(self, encrypted_data: str) -> dict[str, Any]: + def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: """ Decrypt OAuth parameters. @@ -120,16 +118,13 @@ class SystemOAuthEncrypter: unpadded_data = unpad(decrypted_data, AES.block_size) # Parse JSON - params_json = unpadded_data.decode("utf-8") - oauth_params = json.loads(params_json) + oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) if not isinstance(oauth_params, dict): raise ValueError("Decrypted data is not a valid dictionary") return oauth_params - except (ValueError, TypeError) as e: - raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e except Exception as e: raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e @@ -166,7 +161,7 @@ def get_system_oauth_encrypter() -> SystemOAuthEncrypter: # Convenience functions for backward compatibility -def encrypt_system_oauth_params(oauth_params: str) -> str: +def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: """ Encrypt OAuth parameters using the global encrypter. @@ -179,7 +174,7 @@ def encrypt_system_oauth_params(oauth_params: str) -> str: return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) -def decrypt_system_oauth_params(encrypted_data: str) -> dict[str, Any]: +def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: """ Decrypt OAuth parameters using the global encrypter. diff --git a/api/models/tools.py b/api/models/tools.py index d6a65885a7..7c8b5853ba 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -331,11 +331,13 @@ class MCPToolProvider(Base): provider_controller = MCPToolProviderController._from_db(self) - return create_provider_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=self.tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], cache=NoOpProviderCredentialCache(), - )[0].decrypt(self.credentials) + ) + + return encrypter.decrypt(self.credentials) # type: ignore class ToolModelInvoke(Base): diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ced2118a6c..430575b532 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,6 +1,7 @@ import json import logging import re +from collections.abc import Mapping from pathlib import Path from typing import Any, Optional @@ -475,7 +476,7 @@ class BuiltinToolManageService: return user_client is not None and user_client.enabled @staticmethod - def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None: + def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None: """ get builtin tool provider """ @@ -497,7 +498,7 @@ class BuiltinToolManageService: ) .first() ) - oauth_params: dict[str, Any] | None = None + oauth_params: Mapping[str, Any] | None = None if user_client: oauth_params = encrypter.decrypt(user_client.oauth_params) return oauth_params diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py new file mode 100644 index 0000000000..e69de29bb2