refactor: update OAuth parameter handling to use Pydantic TypeAdapter for validation and improve type annotations across multiple files

pull/22550/head
Yeuoly 10 months ago
parent b8f79a7cb1
commit 1ee69f8a54

@ -2,10 +2,11 @@ import base64
import json
import logging
import secrets
from typing import Optional
from typing import Any, Optional
import click
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
@ -1174,12 +1175,12 @@ def setup_system_tool_oauth_client(provider, client_params):
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
json.loads(client_params)
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params)
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

@ -17,7 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: Optional[CredentialType] = CredentialType.API_KEY
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@ -1,12 +1,13 @@
import base64
import hashlib
import json
import logging
from collections.abc import Mapping
from typing import Any, Optional
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config
@ -42,7 +43,7 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: str) -> str:
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
@ -57,9 +58,6 @@ class SystemOAuthEncrypter:
ValueError: If oauth_params is invalid
"""
if not oauth_params:
raise ValueError("oauth_params cannot be empty")
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
@ -68,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(oauth_params.encode("utf-8"), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@ -80,7 +78,7 @@ class SystemOAuthEncrypter:
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> dict[str, Any]:
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
@ -120,16 +118,13 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
params_json = unpadded_data.decode("utf-8")
oauth_params = json.loads(params_json)
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
except (ValueError, TypeError) as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
@ -166,7 +161,7 @@ def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: str) -> str:
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
@ -179,7 +174,7 @@ def encrypt_system_oauth_params(oauth_params: str) -> str:
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> dict[str, Any]:
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.

@ -331,11 +331,13 @@ class MCPToolProvider(Base):
provider_controller = MCPToolProviderController._from_db(self)
return create_provider_encrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)[0].decrypt(self.credentials)
)
return encrypter.decrypt(self.credentials) # type: ignore
class ToolModelInvoke(Base):

@ -1,6 +1,7 @@
import json
import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
@ -475,7 +476,7 @@ class BuiltinToolManageService:
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None:
def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
"""
get builtin tool provider
"""
@ -497,7 +498,7 @@ class BuiltinToolManageService:
)
.first()
)
oauth_params: dict[str, Any] | None = None
oauth_params: Mapping[str, Any] | None = None
if user_client:
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params

Loading…
Cancel
Save