# Repository-viewer residue captured together with this file (not part of the module):
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# gcgj-dify-1.7.0/api/tests/unit_tests/services/test_mfa_service.py
#
# 362 lines
# 15 KiB
# Python

import json
import unittest
from unittest.mock import Mock, patch
from datetime import datetime
import pytest
from models.account import Account, AccountMFASettings
from services.mfa_service import MFAService
class TestMFAService(unittest.TestCase):
    """Unit tests for ``MFAService``; model queries and the DB session are mocked."""

    # ------------------------------------------------------------------
    # Fixtures and helpers
    # ------------------------------------------------------------------

    def setUp(self):
        """Create a spec'd account mock and a disabled, empty MFA settings mock."""
        account = Mock(spec=Account)
        account.id = "test-account-id"
        account.email = "test@example.com"
        self.account = account

        settings = Mock(spec=AccountMFASettings)
        settings.account_id = account.id
        settings.enabled = False
        for blank_field in ("secret", "backup_codes", "setup_at"):
            setattr(settings, blank_field, None)
        self.mfa_settings = settings

    @staticmethod
    def _stub_settings_lookup(mock_query, found):
        """Make ``mock_query.filter_by(...).first()`` return *found*."""
        mock_query.filter_by.return_value.first.return_value = found

    def _mark_mfa_enabled(self):
        """Flip the settings mock to enabled with the secret ``"test_secret"``."""
        self.mfa_settings.enabled = True
        self.mfa_settings.secret = "test_secret"

    # ------------------------------------------------------------------
    # Secret / backup code generation
    # ------------------------------------------------------------------

    def test_generate_secret(self):
        """A generated secret is a string of 32 characters."""
        generated = MFAService.generate_secret()

        self.assertIsInstance(generated, str)
        self.assertEqual(len(generated), 32)  # Base32 length

    def test_generate_backup_codes(self):
        """Eight backup codes come back, each an 8-character string."""
        backup_codes = MFAService.generate_backup_codes()

        self.assertEqual(len(backup_codes), 8)
        for backup_code in backup_codes:
            self.assertIsInstance(backup_code, str)
            self.assertEqual(len(backup_code), 8)  # 4 hex bytes = 8 chars

    # ------------------------------------------------------------------
    # TOTP verification
    # ------------------------------------------------------------------

    @patch("pyotp.TOTP.verify")
    def test_verify_totp_valid(self, mock_verify):
        """An accepted token yields True and is verified with ``valid_window=1``."""
        mock_verify.return_value = True

        self.assertTrue(MFAService.verify_totp("test_secret", "123456"))
        mock_verify.assert_called_once_with("123456", valid_window=1)

    @patch("pyotp.TOTP.verify")
    def test_verify_totp_invalid(self, mock_verify):
        """A rejected token yields False."""
        mock_verify.return_value = False

        self.assertFalse(MFAService.verify_totp("test_secret", "invalid"))

    # ------------------------------------------------------------------
    # Settings lookup / creation
    # ------------------------------------------------------------------

    @patch("services.mfa_service.db.session")
    @patch("models.account.AccountMFASettings.query")
    def test_get_or_create_mfa_settings_existing(self, mock_query, mock_session):
        """Settings returned by the query are handed back, looked up by account id."""
        self._stub_settings_lookup(mock_query, self.mfa_settings)

        found = MFAService.get_or_create_mfa_settings(self.account)

        self.assertEqual(found, self.mfa_settings)
        mock_query.filter_by.assert_called_once_with(account_id=self.account.id)

    @patch("services.mfa_service.db.session")
    @patch("models.account.AccountMFASettings.query")
    @patch("models.account.AccountMFASettings")
    def test_get_or_create_mfa_settings_new(self, mock_mfa_class, mock_query, mock_session):
        """When the query finds nothing, a new settings object is added and committed."""
        self._stub_settings_lookup(mock_query, None)
        fresh_settings = Mock()
        mock_mfa_class.return_value = fresh_settings

        created = MFAService.get_or_create_mfa_settings(self.account)

        self.assertEqual(created, fresh_settings)
        mock_session.add.assert_called_once_with(fresh_settings)
        mock_session.commit.assert_called_once()

    # ------------------------------------------------------------------
    # Backup code verification
    # ------------------------------------------------------------------

    @patch("services.mfa_service.db.session")
    def test_verify_backup_code_valid(self, mock_session):
        """A matching backup code is accepted and removed from the stored list."""
        self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"])

        accepted = MFAService.verify_backup_code(self.mfa_settings, "ABCD1234")

        self.assertTrue(accepted)
        # The used code must be gone while the untouched one remains.
        leftover = json.loads(self.mfa_settings.backup_codes)
        self.assertNotIn("ABCD1234", leftover)
        self.assertIn("EFGH5678", leftover)

    def test_verify_backup_code_invalid(self):
        """A code that is not in the stored list is rejected."""
        self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"])

        self.assertFalse(MFAService.verify_backup_code(self.mfa_settings, "INVALID"))

    def test_verify_backup_code_no_codes(self):
        """Any code is rejected when no backup codes are stored."""
        self.mfa_settings.backup_codes = None

        self.assertFalse(MFAService.verify_backup_code(self.mfa_settings, "ABCD1234"))

    # ------------------------------------------------------------------
    # setup_mfa
    # ------------------------------------------------------------------

    @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
    @patch("services.mfa_service.MFAService.verify_totp")
    @patch("services.mfa_service.MFAService.generate_backup_codes")
    @patch("services.mfa_service.db.session")
    def test_setup_mfa_success(self, mock_session, mock_gen_codes, mock_verify, mock_get_settings):
        """A valid token enables MFA, stores the backup codes and returns them."""
        expected_codes = ("CODE1", "CODE2")
        self.mfa_settings.secret = "test_secret"
        mock_get_settings.return_value = self.mfa_settings
        mock_verify.return_value = True
        # Hand the service its own list so the expectations below stay independent.
        mock_gen_codes.return_value = list(expected_codes)

        outcome = MFAService.setup_mfa(self.account, "123456")

        self.assertTrue(self.mfa_settings.enabled)
        self.assertEqual(self.mfa_settings.backup_codes, json.dumps(list(expected_codes)))
        self.assertIsInstance(self.mfa_settings.setup_at, datetime)
        self.assertEqual(outcome["backup_codes"], list(expected_codes))

    @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
    def test_setup_mfa_already_enabled(self, mock_get_settings):
        """Setup raises ValueError mentioning "already enabled" when MFA is on."""
        self.mfa_settings.enabled = True
        mock_get_settings.return_value = self.mfa_settings

        with self.assertRaisesRegex(ValueError, "already enabled"):
            MFAService.setup_mfa(self.account, "123456")

    @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
    def test_setup_mfa_no_secret(self, mock_get_settings):
        """Setup raises ValueError mentioning "secret not generated" without a secret."""
        mock_get_settings.return_value = self.mfa_settings

        with self.assertRaisesRegex(ValueError, "secret not generated"):
            MFAService.setup_mfa(self.account, "123456")

    @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
    @patch("services.mfa_service.MFAService.verify_totp")
    def test_setup_mfa_invalid_token(self, mock_verify, mock_get_settings):
        """Setup raises ValueError mentioning "Invalid TOTP token" on a bad token."""
        self.mfa_settings.secret = "test_secret"
        mock_get_settings.return_value = self.mfa_settings
        mock_verify.return_value = False

        with self.assertRaisesRegex(ValueError, "Invalid TOTP token"):
            MFAService.setup_mfa(self.account, "invalid")

    # ------------------------------------------------------------------
    # is_mfa_required
    # ------------------------------------------------------------------

    @patch("models.account.AccountMFASettings.query")
    def test_is_mfa_required_enabled(self, mock_query):
        """MFA is required when the settings are enabled and have a secret."""
        self._mark_mfa_enabled()
        self._stub_settings_lookup(mock_query, self.mfa_settings)

        self.assertTrue(MFAService.is_mfa_required(self.account))

    @patch("models.account.AccountMFASettings.query")
    def test_is_mfa_required_disabled(self, mock_query):
        """MFA is not required when the settings are disabled."""
        self._stub_settings_lookup(mock_query, self.mfa_settings)

        self.assertFalse(MFAService.is_mfa_required(self.account))

    @patch("models.account.AccountMFASettings.query")
    def test_is_mfa_required_no_settings(self, mock_query):
        """MFA is not required when the query finds no settings."""
        self._stub_settings_lookup(mock_query, None)

        self.assertFalse(MFAService.is_mfa_required(self.account))

    # ------------------------------------------------------------------
    # authenticate_with_mfa
    # ------------------------------------------------------------------

    @patch("models.account.AccountMFASettings.query")
    @patch("services.mfa_service.MFAService.verify_totp")
    @patch("services.mfa_service.MFAService.verify_backup_code")
    def test_authenticate_with_mfa_totp_success(self, mock_verify_backup, mock_verify_totp, mock_query):
        """A valid TOTP authenticates without consulting the backup codes."""
        self._mark_mfa_enabled()
        self._stub_settings_lookup(mock_query, self.mfa_settings)
        mock_verify_totp.return_value = True

        authenticated = MFAService.authenticate_with_mfa(self.account, "123456")

        self.assertTrue(authenticated)
        mock_verify_totp.assert_called_once_with("test_secret", "123456")
        mock_verify_backup.assert_not_called()

    @patch("models.account.AccountMFASettings.query")
    @patch("services.mfa_service.MFAService.verify_totp")
    @patch("services.mfa_service.MFAService.verify_backup_code")
    def test_authenticate_with_mfa_backup_success(self, mock_verify_backup, mock_verify_totp, mock_query):
        """When TOTP fails, a valid backup code still authenticates."""
        self._mark_mfa_enabled()
        self._stub_settings_lookup(mock_query, self.mfa_settings)
        mock_verify_totp.return_value = False
        mock_verify_backup.return_value = True

        authenticated = MFAService.authenticate_with_mfa(self.account, "BACKUP123")

        self.assertTrue(authenticated)
        mock_verify_totp.assert_called_once_with("test_secret", "BACKUP123")
        mock_verify_backup.assert_called_once_with(self.mfa_settings, "BACKUP123")

    @patch("models.account.AccountMFASettings.query")
    def test_authenticate_with_mfa_disabled(self, mock_query):
        """Authentication passes when the stored settings are not enabled."""
        self._stub_settings_lookup(mock_query, self.mfa_settings)

        self.assertTrue(MFAService.authenticate_with_mfa(self.account, "123456"))

    # ------------------------------------------------------------------
    # get_mfa_status
    # ------------------------------------------------------------------

    @patch("models.account.AccountMFASettings.query")
    def test_get_mfa_status_enabled(self, mock_query):
        """Enabled settings report the ISO setup time and that backup codes exist."""
        self.mfa_settings.enabled = True
        self.mfa_settings.setup_at = datetime(year=2025, month=1, day=1, hour=12)
        self.mfa_settings.backup_codes = json.dumps(["CODE1", "CODE2"])
        self._stub_settings_lookup(mock_query, self.mfa_settings)

        status = MFAService.get_mfa_status(self.account)

        self.assertEqual(
            status,
            {
                "enabled": True,
                "setup_at": "2025-01-01T12:00:00",
                "has_backup_codes": True,
            },
        )

    @patch("models.account.AccountMFASettings.query")
    def test_get_mfa_status_no_settings(self, mock_query):
        """With no settings found, the status is disabled with no time or codes."""
        self._stub_settings_lookup(mock_query, None)

        status = MFAService.get_mfa_status(self.account)

        self.assertEqual(
            status,
            {
                "enabled": False,
                "setup_at": None,
                "has_backup_codes": False,
            },
        )

    # ------------------------------------------------------------------
    # QR code generation
    # ------------------------------------------------------------------

    @patch("qrcode.QRCode")
    @patch("pyotp.TOTP")
    def test_generate_qr_code(self, mock_totp_class, mock_qr_class):
        """The QR code is a PNG data URI built for the account email and "Dify"."""
        # TOTP stand-in that hands back a fixed provisioning URI.
        totp = Mock()
        totp.provisioning_uri.return_value = "otpauth://totp/test"
        mock_totp_class.return_value = totp

        # QR code stand-in whose rendered image is another mock.
        qr_code = Mock()
        qr_code.make_image.return_value = Mock()
        mock_qr_class.return_value = qr_code

        # The image buffer and the base64 encoder are replaced for the call.
        with patch("io.BytesIO"), patch("base64.b64encode") as mock_b64:
            mock_b64.return_value.decode.return_value = "base64data"
            data_uri = MFAService.generate_qr_code(self.account, "test_secret")

        self.assertEqual(data_uri, "data:image/png;base64,base64data")
        totp.provisioning_uri.assert_called_once_with(
            name=self.account.email,
            issuer_name="Dify",
        )

    # ------------------------------------------------------------------
    # disable_mfa
    # ------------------------------------------------------------------

    @patch("services.account_service.AccountService.check_account_password")
    @patch("models.account.AccountMFASettings.query")
    @patch("services.mfa_service.db.session")
    def test_disable_mfa_success(self, mock_session, mock_query, mock_check_password):
        """A correct password disables MFA, clears its fields and commits once."""
        mock_check_password.return_value = True
        self._stub_settings_lookup(mock_query, self.mfa_settings)

        disabled = MFAService.disable_mfa(self.account, "correct_password")

        self.assertTrue(disabled)
        self.assertFalse(self.mfa_settings.enabled)
        for cleared_field in ("secret", "backup_codes", "setup_at"):
            self.assertIsNone(getattr(self.mfa_settings, cleared_field))
        mock_session.commit.assert_called_once()

    @patch("services.account_service.AccountService.check_account_password")
    def test_disable_mfa_wrong_password(self, mock_check_password):
        """A failed password check makes ``disable_mfa`` return False."""
        mock_check_password.return_value = False

        self.assertFalse(MFAService.disable_mfa(self.account, "wrong_password"))

    @patch("services.account_service.AccountService.check_account_password")
    @patch("models.account.AccountMFASettings.query")
    def test_disable_mfa_no_settings(self, mock_query, mock_check_password):
        """Disabling returns True when no settings are found (already disabled)."""
        mock_check_password.return_value = True
        self._stub_settings_lookup(mock_query, None)

        self.assertTrue(MFAService.disable_mfa(self.account, "correct_password"))

    # ------------------------------------------------------------------
    # generate_mfa_setup_data
    # ------------------------------------------------------------------

    @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
    @patch("services.mfa_service.MFAService.generate_secret")
    @patch("services.mfa_service.MFAService.generate_qr_code")
    @patch("services.mfa_service.db.session")
    def test_generate_mfa_setup_data_success(self, mock_session, mock_gen_qr, mock_gen_secret, mock_get_settings):
        """Setup data carries the new secret and QR code; the secret is stored and committed."""
        mock_get_settings.return_value = self.mfa_settings
        mock_gen_secret.return_value = "NEWSECRET123"
        mock_gen_qr.return_value = "data:image/png;base64,qrdata"

        setup_data = MFAService.generate_mfa_setup_data(self.account)

        self.assertEqual(setup_data["secret"], "NEWSECRET123")
        self.assertEqual(setup_data["qr_code"], "data:image/png;base64,qrdata")
        self.assertEqual(self.mfa_settings.secret, "NEWSECRET123")
        mock_session.commit.assert_called_once()

    @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
    def test_generate_mfa_setup_data_already_enabled(self, mock_get_settings):
        """Generating setup data raises ValueError mentioning "already enabled" when MFA is on."""
        self.mfa_settings.enabled = True
        mock_get_settings.return_value = self.mfa_settings

        with self.assertRaisesRegex(ValueError, "already enabled"):
            MFAService.generate_mfa_setup_data(self.account)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()