refactor: update OAuth parameter handling to use Pydantic TypeAdapter for validation and improve type annotations across multiple files

pull/22550/head
Yeuoly 10 months ago
parent b8f79a7cb1
commit 1ee69f8a54

@ -2,10 +2,11 @@ import base64
import json import json
import logging import logging
import secrets import secrets
from typing import Optional from typing import Any, Optional
import click import click
from flask import current_app from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -1174,12 +1175,12 @@ def setup_system_tool_oauth_client(provider, client_params):
try: try:
# json validate # json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
json.loads(client_params) client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green")) click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params) oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green")) click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e: except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

@ -17,7 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict) credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: Optional[CredentialType] = CredentialType.API_KEY credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@ -1,12 +1,13 @@
import base64 import base64
import hashlib import hashlib
import json
import logging import logging
from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional
from Crypto.Cipher import AES from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config from configs import dify_config
@ -42,7 +43,7 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256 # Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest() self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: str) -> str: def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
""" """
Encrypt OAuth parameters. Encrypt OAuth parameters.
@ -57,9 +58,6 @@ class SystemOAuthEncrypter:
ValueError: If oauth_params is invalid ValueError: If oauth_params is invalid
""" """
if not oauth_params:
raise ValueError("oauth_params cannot be empty")
try: try:
# Generate random IV (16 bytes) # Generate random IV (16 bytes)
iv = get_random_bytes(16) iv = get_random_bytes(16)
@ -68,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv) cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data # Encrypt data
padded_data = pad(oauth_params.encode("utf-8"), AES.block_size) padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data) encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data # Combine IV and encrypted data
@ -80,7 +78,7 @@ class SystemOAuthEncrypter:
except Exception as e: except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> dict[str, Any]: def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
""" """
Decrypt OAuth parameters. Decrypt OAuth parameters.
@ -120,16 +118,13 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size) unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON # Parse JSON
params_json = unpadded_data.decode("utf-8") oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
oauth_params = json.loads(params_json)
if not isinstance(oauth_params, dict): if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary") raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params return oauth_params
except (ValueError, TypeError) as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
except Exception as e: except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
@ -166,7 +161,7 @@ def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
# Convenience functions for backward compatibility # Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: str) -> str: def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
""" """
Encrypt OAuth parameters using the global encrypter. Encrypt OAuth parameters using the global encrypter.
@ -179,7 +174,7 @@ def encrypt_system_oauth_params(oauth_params: str) -> str:
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> dict[str, Any]: def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
""" """
Decrypt OAuth parameters using the global encrypter. Decrypt OAuth parameters using the global encrypter.

@ -331,11 +331,13 @@ class MCPToolProvider(Base):
provider_controller = MCPToolProviderController._from_db(self) provider_controller = MCPToolProviderController._from_db(self)
return create_provider_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
)[0].decrypt(self.credentials) )
return encrypter.decrypt(self.credentials) # type: ignore
class ToolModelInvoke(Base): class ToolModelInvoke(Base):

@ -1,6 +1,7 @@
import json import json
import logging import logging
import re import re
from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
@ -475,7 +476,7 @@ class BuiltinToolManageService:
return user_client is not None and user_client.enabled return user_client is not None and user_client.enabled
@staticmethod @staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None: def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
""" """
get builtin tool provider get builtin tool provider
""" """
@ -497,7 +498,7 @@ class BuiltinToolManageService:
) )
.first() .first()
) )
oauth_params: dict[str, Any] | None = None oauth_params: Mapping[str, Any] | None = None
if user_client: if user_client:
oauth_params = encrypter.decrypt(user_client.oauth_params) oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params return oauth_params

Loading…
Cancel
Save