You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
362 lines
15 KiB
Python
362 lines
15 KiB
Python
import json
|
|
import unittest
|
|
from unittest.mock import Mock, patch
|
|
from datetime import datetime
|
|
|
|
import pytest
|
|
|
|
from models.account import Account, AccountMFASettings
|
|
from services.mfa_service import MFAService
|
|
|
|
|
|
class TestMFAService(unittest.TestCase):
    """Unit tests for ``MFAService``.

    The tests avoid a real database: the service's ``db.session`` and the
    ``AccountMFASettings.query`` attribute are replaced with mocks. The account
    and MFA-settings objects are ``Mock`` instances spec'd against the real
    models.

    Stacked ``@patch`` decorators inject their mocks bottom-up. The decorator
    closest to the ``def`` supplies the first mock argument.
    """

    def setUp(self):
        # Account stub carrying the fields these tests rely on.
        self.account = Mock(spec=Account)
        self.account.id = "test-account-id"
        self.account.email = "test@example.com"

        # MFA settings in the "never set up" state. Individual tests mutate
        # these fields to model other states.
        self.mfa_settings = Mock(spec=AccountMFASettings)
        self.mfa_settings.account_id = self.account.id
        self.mfa_settings.enabled = False
        self.mfa_settings.secret = None
        self.mfa_settings.backup_codes = None
        self.mfa_settings.setup_at = None

    def test_generate_secret(self):
        """Test secret generation."""
        secret = MFAService.generate_secret()
        self.assertIsInstance(secret, str)
        self.assertEqual(len(secret), 32)  # Base32 length

    def test_generate_backup_codes(self):
        """Test backup codes generation."""
        codes = MFAService.generate_backup_codes()
        self.assertEqual(len(codes), 8)
        for code in codes:
            self.assertIsInstance(code, str)
            self.assertEqual(len(code), 8)  # 4 hex bytes = 8 chars

    @patch('pyotp.TOTP.verify')
    def test_verify_totp_valid(self, mock_verify):
        """Test TOTP verification with valid token."""
        mock_verify.return_value = True

        result = MFAService.verify_totp("test_secret", "123456")

        self.assertTrue(result)
        # The patched class attribute is a plain Mock (not a descriptor), so
        # the instance is not passed as an argument.
        mock_verify.assert_called_once_with("123456", valid_window=1)

    @patch('pyotp.TOTP.verify')
    def test_verify_totp_invalid(self, mock_verify):
        """Test TOTP verification with invalid token."""
        mock_verify.return_value = False

        result = MFAService.verify_totp("test_secret", "invalid")

        self.assertFalse(result)

    @patch('services.mfa_service.db.session')
    @patch('models.account.AccountMFASettings.query')
    def test_get_or_create_mfa_settings_existing(self, mock_query, mock_session):
        """Test getting existing MFA settings."""
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings

        result = MFAService.get_or_create_mfa_settings(self.account)

        self.assertEqual(result, self.mfa_settings)
        mock_query.filter_by.assert_called_once_with(account_id=self.account.id)

    @patch('services.mfa_service.db.session')
    @patch('services.mfa_service.AccountMFASettings')
    def test_get_or_create_mfa_settings_new(self, mock_mfa_class, mock_session):
        """Test creating new MFA settings when none exist for the account."""
        # Patch the model where the service looks it up. Replacing
        # ``models.account.AccountMFASettings`` would not affect a reference
        # the service already imported by name (presumably
        # ``from models.account import AccountMFASettings`` -- TODO confirm).
        # In that case the real query property and constructor would run.
        # Patching the service's reference covers both the lookup and the
        # construction.
        mock_mfa_class.query.filter_by.return_value.first.return_value = None
        mock_new_settings = Mock()
        mock_mfa_class.return_value = mock_new_settings

        result = MFAService.get_or_create_mfa_settings(self.account)

        self.assertEqual(result, mock_new_settings)
        mock_mfa_class.query.filter_by.assert_called_once_with(account_id=self.account.id)
        mock_session.add.assert_called_once_with(mock_new_settings)
        mock_session.commit.assert_called_once()

    @patch('services.mfa_service.db.session')
    def test_verify_backup_code_valid(self, mock_session):
        """Test backup code verification with valid code."""
        # db.session is patched so that persisting the consumed code (if the
        # service does so) never reaches a real database.
        self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"])

        result = MFAService.verify_backup_code(self.mfa_settings, "ABCD1234")

        self.assertTrue(result)
        # Check that the code was removed
        remaining_codes = json.loads(self.mfa_settings.backup_codes)
        self.assertNotIn("ABCD1234", remaining_codes)
        self.assertIn("EFGH5678", remaining_codes)

    def test_verify_backup_code_invalid(self):
        """Test backup code verification with invalid code."""
        self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"])

        result = MFAService.verify_backup_code(self.mfa_settings, "INVALID")

        self.assertFalse(result)

    def test_verify_backup_code_no_codes(self):
        """Test backup code verification with no backup codes."""
        self.mfa_settings.backup_codes = None

        result = MFAService.verify_backup_code(self.mfa_settings, "ABCD1234")

        self.assertFalse(result)

    @patch('services.mfa_service.MFAService.get_or_create_mfa_settings')
    @patch('services.mfa_service.MFAService.verify_totp')
    @patch('services.mfa_service.MFAService.generate_backup_codes')
    @patch('services.mfa_service.db.session')
    def test_setup_mfa_success(self, mock_session, mock_gen_codes, mock_verify, mock_get_settings):
        """Test successful MFA setup."""
        mock_get_settings.return_value = self.mfa_settings
        self.mfa_settings.secret = "test_secret"
        mock_verify.return_value = True
        mock_gen_codes.return_value = ["CODE1", "CODE2"]

        result = MFAService.setup_mfa(self.account, "123456")

        self.assertTrue(self.mfa_settings.enabled)
        self.assertEqual(self.mfa_settings.backup_codes, json.dumps(["CODE1", "CODE2"]))
        self.assertIsInstance(self.mfa_settings.setup_at, datetime)
        self.assertEqual(result["backup_codes"], ["CODE1", "CODE2"])

    @patch('services.mfa_service.MFAService.get_or_create_mfa_settings')
    def test_setup_mfa_already_enabled(self, mock_get_settings):
        """Test MFA setup when already enabled."""
        self.mfa_settings.enabled = True
        mock_get_settings.return_value = self.mfa_settings

        with self.assertRaises(ValueError) as context:
            MFAService.setup_mfa(self.account, "123456")

        self.assertIn("already enabled", str(context.exception))

    @patch('services.mfa_service.MFAService.get_or_create_mfa_settings')
    def test_setup_mfa_no_secret(self, mock_get_settings):
        """Test MFA setup without secret."""
        # setUp leaves ``secret`` as None, which is the condition under test.
        mock_get_settings.return_value = self.mfa_settings

        with self.assertRaises(ValueError) as context:
            MFAService.setup_mfa(self.account, "123456")

        self.assertIn("secret not generated", str(context.exception))

    @patch('services.mfa_service.MFAService.get_or_create_mfa_settings')
    @patch('services.mfa_service.MFAService.verify_totp')
    def test_setup_mfa_invalid_token(self, mock_verify, mock_get_settings):
        """Test MFA setup with invalid TOTP token."""
        mock_get_settings.return_value = self.mfa_settings
        self.mfa_settings.secret = "test_secret"
        mock_verify.return_value = False

        with self.assertRaises(ValueError) as context:
            MFAService.setup_mfa(self.account, "invalid")

        self.assertIn("Invalid TOTP token", str(context.exception))

    @patch('models.account.AccountMFASettings.query')
    def test_is_mfa_required_enabled(self, mock_query):
        """Test MFA requirement check when enabled."""
        self.mfa_settings.enabled = True
        self.mfa_settings.secret = "test_secret"
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings

        result = MFAService.is_mfa_required(self.account)

        self.assertTrue(result)

    @patch('models.account.AccountMFASettings.query')
    def test_is_mfa_required_disabled(self, mock_query):
        """Test MFA requirement check when disabled."""
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings

        result = MFAService.is_mfa_required(self.account)

        self.assertFalse(result)

    @patch('models.account.AccountMFASettings.query')
    def test_is_mfa_required_no_settings(self, mock_query):
        """Test MFA requirement check with no settings."""
        mock_query.filter_by.return_value.first.return_value = None

        result = MFAService.is_mfa_required(self.account)

        self.assertFalse(result)

    @patch('models.account.AccountMFASettings.query')
    @patch('services.mfa_service.MFAService.verify_totp')
    @patch('services.mfa_service.MFAService.verify_backup_code')
    def test_authenticate_with_mfa_totp_success(self, mock_verify_backup, mock_verify_totp, mock_query):
        """Test MFA authentication with valid TOTP."""
        self.mfa_settings.enabled = True
        self.mfa_settings.secret = "test_secret"
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings
        mock_verify_totp.return_value = True

        result = MFAService.authenticate_with_mfa(self.account, "123456")

        self.assertTrue(result)
        mock_verify_totp.assert_called_once_with("test_secret", "123456")
        # A valid TOTP must short-circuit the backup-code path.
        mock_verify_backup.assert_not_called()

    @patch('models.account.AccountMFASettings.query')
    @patch('services.mfa_service.MFAService.verify_totp')
    @patch('services.mfa_service.MFAService.verify_backup_code')
    def test_authenticate_with_mfa_backup_success(self, mock_verify_backup, mock_verify_totp, mock_query):
        """Test MFA authentication with valid backup code."""
        self.mfa_settings.enabled = True
        self.mfa_settings.secret = "test_secret"
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings
        mock_verify_totp.return_value = False
        mock_verify_backup.return_value = True

        result = MFAService.authenticate_with_mfa(self.account, "BACKUP123")

        self.assertTrue(result)
        # TOTP is tried first; the backup code is the fallback.
        mock_verify_totp.assert_called_once_with("test_secret", "BACKUP123")
        mock_verify_backup.assert_called_once_with(self.mfa_settings, "BACKUP123")

    @patch('models.account.AccountMFASettings.query')
    def test_authenticate_with_mfa_disabled(self, mock_query):
        """Test MFA authentication when disabled."""
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings

        result = MFAService.authenticate_with_mfa(self.account, "123456")

        self.assertTrue(result)

    @patch('models.account.AccountMFASettings.query')
    def test_get_mfa_status_enabled(self, mock_query):
        """Test getting MFA status when enabled."""
        self.mfa_settings.enabled = True
        self.mfa_settings.setup_at = datetime(2025, 1, 1, 12, 0, 0)
        self.mfa_settings.backup_codes = json.dumps(["CODE1", "CODE2"])
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings

        result = MFAService.get_mfa_status(self.account)

        expected = {
            "enabled": True,
            "setup_at": "2025-01-01T12:00:00",
            "has_backup_codes": True,
        }
        self.assertEqual(result, expected)

    @patch('models.account.AccountMFASettings.query')
    def test_get_mfa_status_no_settings(self, mock_query):
        """Test getting MFA status with no settings."""
        mock_query.filter_by.return_value.first.return_value = None

        result = MFAService.get_mfa_status(self.account)

        expected = {
            "enabled": False,
            "setup_at": None,
            "has_backup_codes": False,
        }
        self.assertEqual(result, expected)

    @patch('qrcode.QRCode')
    @patch('pyotp.TOTP')
    def test_generate_qr_code(self, mock_totp_class, mock_qr_class):
        """Test QR code generation."""
        # Mock TOTP
        mock_totp = Mock()
        mock_totp.provisioning_uri.return_value = "otpauth://totp/test"
        mock_totp_class.return_value = mock_totp

        # Mock QR code
        mock_qr = Mock()
        mock_img = Mock()
        mock_qr.make_image.return_value = mock_img
        mock_qr_class.return_value = mock_qr

        # Mock the image buffer and its base64 encoding so the expected
        # data-URL payload is deterministic.
        with patch('io.BytesIO'), patch('base64.b64encode') as mock_b64:
            mock_b64.return_value.decode.return_value = "base64data"

            result = MFAService.generate_qr_code(self.account, "test_secret")

        self.assertEqual(result, "data:image/png;base64,base64data")
        mock_totp.provisioning_uri.assert_called_once_with(
            name=self.account.email,
            issuer_name="Dify",
        )

    @patch('services.account_service.AccountService.check_account_password')
    @patch('models.account.AccountMFASettings.query')
    @patch('services.mfa_service.db.session')
    def test_disable_mfa_success(self, mock_session, mock_query, mock_check_password):
        """Test successful MFA disable clears every MFA field."""
        # Start from a fully enabled state. The setUp defaults are already
        # cleared, so asserting on them alone would pass even if disable_mfa
        # changed nothing.
        self.mfa_settings.enabled = True
        self.mfa_settings.secret = "test_secret"
        self.mfa_settings.backup_codes = json.dumps(["CODE1", "CODE2"])
        self.mfa_settings.setup_at = datetime(2025, 1, 1, 12, 0, 0)
        mock_check_password.return_value = True
        mock_query.filter_by.return_value.first.return_value = self.mfa_settings

        result = MFAService.disable_mfa(self.account, "correct_password")

        self.assertTrue(result)
        self.assertFalse(self.mfa_settings.enabled)
        self.assertIsNone(self.mfa_settings.secret)
        self.assertIsNone(self.mfa_settings.backup_codes)
        self.assertIsNone(self.mfa_settings.setup_at)
        mock_session.commit.assert_called_once()

    @patch('services.account_service.AccountService.check_account_password')
    def test_disable_mfa_wrong_password(self, mock_check_password):
        """Test MFA disable with wrong password."""
        mock_check_password.return_value = False

        result = MFAService.disable_mfa(self.account, "wrong_password")

        self.assertFalse(result)

    @patch('services.account_service.AccountService.check_account_password')
    @patch('models.account.AccountMFASettings.query')
    def test_disable_mfa_no_settings(self, mock_query, mock_check_password):
        """Test MFA disable when no settings exist."""
        mock_check_password.return_value = True
        mock_query.filter_by.return_value.first.return_value = None

        result = MFAService.disable_mfa(self.account, "correct_password")

        self.assertTrue(result)  # Already disabled

    @patch('services.mfa_service.MFAService.get_or_create_mfa_settings')
    @patch('services.mfa_service.MFAService.generate_secret')
    @patch('services.mfa_service.MFAService.generate_qr_code')
    @patch('services.mfa_service.db.session')
    def test_generate_mfa_setup_data_success(self, mock_session, mock_gen_qr, mock_gen_secret, mock_get_settings):
        """Test successful MFA setup data generation."""
        mock_get_settings.return_value = self.mfa_settings
        mock_gen_secret.return_value = "NEWSECRET123"
        mock_gen_qr.return_value = "data:image/png;base64,qrdata"

        result = MFAService.generate_mfa_setup_data(self.account)

        self.assertEqual(result["secret"], "NEWSECRET123")
        self.assertEqual(result["qr_code"], "data:image/png;base64,qrdata")
        self.assertEqual(self.mfa_settings.secret, "NEWSECRET123")
        mock_session.commit.assert_called_once()

    @patch('services.mfa_service.MFAService.get_or_create_mfa_settings')
    def test_generate_mfa_setup_data_already_enabled(self, mock_get_settings):
        """Test MFA setup data generation when already enabled."""
        self.mfa_settings.enabled = True
        mock_get_settings.return_value = self.mfa_settings

        with self.assertRaises(ValueError) as context:
            MFAService.generate_mfa_setup_data(self.account)

        self.assertIn("already enabled", str(context.exception))
|
|
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly hands control to unittest's command-line
    # runner, which discovers and runs the TestCase classes defined above.
    unittest.main()