diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index aec21c1bb3..712ce5ef8a 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -14,8 +14,6 @@ from controllers.console.auth.error import ( EmailPasswordLoginLimitError, InvalidEmailError, InvalidTokenError, - MFARequiredError, - MFATokenRequiredError, ) from controllers.console.error import ( AccountBannedError, diff --git a/api/controllers/console/auth/mfa.py b/api/controllers/console/auth/mfa.py index 4c3f76eac1..2241b7f9a3 100644 --- a/api/controllers/console/auth/mfa.py +++ b/api/controllers/console/auth/mfa.py @@ -1,12 +1,8 @@ from typing import cast import flask_login -from flask import request from flask_restful import Resource, reqparse -from controllers.console.auth.error import ( - TokenValidationError, -) from controllers.console.wraps import account_initialization_required from libs.login import login_required from models.account import Account diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 9c79782f47..a76c2585f3 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -389,7 +389,8 @@ api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') # MFA endpoints -from controllers.console.auth.mfa import MFASetupInitApi, MFASetupCompleteApi, MFADisableApi, MFAStatusApi +from controllers.console.auth.mfa import MFADisableApi, MFASetupCompleteApi, MFASetupInitApi, MFAStatusApi + api.add_resource(MFAStatusApi, "/account/mfa/status") api.add_resource(MFASetupInitApi, "/account/mfa/setup") api.add_resource(MFASetupCompleteApi, "/account/mfa/setup/complete") diff --git a/api/services/mfa_service.py b/api/services/mfa_service.py index 250c9db9be..5f1a1456ca 100644 --- a/api/services/mfa_service.py +++ b/api/services/mfa_service.py @@ -2,13 +2,10 @@ import base64 import io import json import secrets -from datetime import datetime, timezone -from typing import Optional +from datetime import UTC, datetime import pyotp import qrcode -from sqlalchemy import and_ -from sqlalchemy.orm import Session from models.account import Account, AccountMFASettings from models.engine import db @@ -53,7 +50,7 @@ class MFAService: # Convert to base64 buffer = io.BytesIO() - img.save(buffer, format='PNG') + img.save(buffer) img_str = base64.b64encode(buffer.getvalue()).decode() return f"data:image/png;base64,{img_str}" @@ -120,7 +117,7 @@ class MFAService: # Enable MFA mfa_settings.enabled = True mfa_settings.backup_codes = json.dumps(backup_codes) - mfa_settings.setup_at = datetime.now(timezone.utc) + mfa_settings.setup_at = datetime.now(UTC) db.session.commit() @@ -176,7 +173,7 @@ class MFAService: def is_mfa_required(account: Account) -> bool: """Check if MFA is required for this account.""" mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() - return mfa_settings and mfa_settings.enabled and mfa_settings.secret is not None + return bool(mfa_settings and mfa_settings.enabled and mfa_settings.secret is not None) @staticmethod def authenticate_with_mfa(account: Account, token: str) -> bool: diff --git a/api/tests/integration_tests/controllers/console/auth/test_login_mfa_integration.py b/api/tests/integration_tests/controllers/console/auth/test_login_mfa_integration.py index d18ae7224c..6e1a806467 100644 --- a/api/tests/integration_tests/controllers/console/auth/test_login_mfa_integration.py +++ b/api/tests/integration_tests/controllers/console/auth/test_login_mfa_integration.py @@ -1,14 +1,8 @@ import json import sys -import pytest from unittest.mock import Mock, patch -from datetime import datetime from flask import Flask -from flask_restful import Api - -from controllers.console.auth.login import LoginApi -from models.account import Account, AccountMFASettings class TestLoginMFAIntegration: @@ -23,10 +17,10 @@ class TestLoginMFAIntegration: @patch('controllers.console.auth.login.AccountService.login') @patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') @patch('controllers.console.auth.login.extract_remote_ip') - def test_login_without_mfa_success(self, mock_extract_ip, mock_reset_limit, + def test_login_without_mfa_success(self, mock_extract_ip, mock_reset_limit, mock_login_service, mock_get_tenants, mock_is_mfa_required, - mock_authenticate, mock_rate_limit, mock_freeze_check, - mock_dify_config, mock_system_features, + mock_authenticate, mock_rate_limit, mock_freeze_check, + mock_dify_config, mock_system_features, test_client, setup_account): """Test successful login without MFA enabled.""" # Setup mocks @@ -37,29 +31,29 @@ class TestLoginMFAIntegration: mock_is_mfa_required.return_value = False mock_get_tenants.return_value = [Mock()] # At least one tenant mock_extract_ip.return_value = "127.0.0.1" - + token_pair_mock = Mock() token_pair_mock.model_dump.return_value = { "access_token": "test_access_token", "refresh_token": "test_refresh_token" } mock_login_service.return_value = token_pair_mock - + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: mock_setup.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f - + response = test_client.post('/console/api/login', json={ "email": setup_account.email, "password": "TestPassword123" }) - + # Debug output if response.status_code != 200: print(f"Status: {response.status_code}", file=sys.stderr) print(f"Data: {response.data}", file=sys.stderr) - + assert response.status_code == 200 data = json.loads(response.data) assert data["result"] == "success" @@ -72,7 +66,7 @@ class TestLoginMFAIntegration: @patch('controllers.console.auth.login.AccountService.authenticate') @patch('controllers.console.auth.login.MFAService.is_mfa_required') def test_login_with_mfa_required_no_token(self, mock_is_mfa_required, mock_authenticate, - mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_rate_limit, mock_freeze_check, mock_dify_config, mock_system_features, test_client, setup_account): """Test login returns mfa_required when MFA is enabled but no token provided.""" # Setup mocks @@ -81,17 +75,17 @@ class TestLoginMFAIntegration: mock_rate_limit.return_value = False mock_authenticate.return_value = setup_account mock_is_mfa_required.return_value = True - + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: mock_setup.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f - + response = test_client.post('/console/api/login', json={ "email": "test@example.com", "password": "TestPassword123" }) - + assert response.status_code == 200 data = json.loads(response.data) assert data["result"] == "fail" @@ -105,7 +99,7 @@ class TestLoginMFAIntegration: @patch('controllers.console.auth.login.MFAService.is_mfa_required') @patch('controllers.console.auth.login.MFAService.authenticate_with_mfa') def test_login_with_mfa_invalid_token(self, mock_auth_mfa, mock_is_mfa_required, mock_authenticate, - mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_rate_limit, mock_freeze_check, mock_dify_config, mock_system_features, test_client, setup_account): """Test login fails with invalid MFA token.""" # Setup mocks @@ -115,18 +109,18 @@ class TestLoginMFAIntegration: mock_authenticate.return_value = setup_account mock_is_mfa_required.return_value = True mock_auth_mfa.return_value = False # Invalid token - + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: mock_setup.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f - + response = test_client.post('/console/api/login', json={ "email": "test@example.com", "password": "TestPassword123", "mfa_code": "invalid_token" }) - + assert response.status_code == 200 data = json.loads(response.data) assert data["result"] == "fail" @@ -144,10 +138,10 @@ class TestLoginMFAIntegration: @patch('controllers.console.auth.login.AccountService.login') @patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') @patch('controllers.console.auth.login.extract_remote_ip') - def test_login_with_mfa_valid_token_success(self, mock_extract_ip, mock_reset_limit, + def test_login_with_mfa_valid_token_success(self, mock_extract_ip, mock_reset_limit, mock_login_service, mock_get_tenants, mock_auth_mfa, mock_is_mfa_required, mock_authenticate, - mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_rate_limit, mock_freeze_check, mock_dify_config, mock_system_features, test_client, setup_account): """Test successful login with valid MFA token.""" # Setup mocks @@ -159,30 +153,30 @@ class TestLoginMFAIntegration: mock_auth_mfa.return_value = True # Valid token mock_get_tenants.return_value = [Mock()] # At least one tenant mock_extract_ip.return_value = "127.0.0.1" - + token_pair_mock = Mock() token_pair_mock.model_dump.return_value = { "access_token": "test_access_token", "refresh_token": "test_refresh_token" } mock_login_service.return_value = token_pair_mock - + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: mock_setup.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f - + response = test_client.post('/console/api/login', json={ "email": "test@example.com", "password": "TestPassword123", "mfa_code": "123456" }) - + assert response.status_code == 200 data = json.loads(response.data) assert data["result"] == "success" assert "access_token" in data["data"] - + # Verify MFA authentication was called mock_auth_mfa.assert_called_once_with(setup_account, "123456") @@ -197,10 +191,10 @@ class TestLoginMFAIntegration: @patch('controllers.console.auth.login.AccountService.login') @patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') @patch('controllers.console.auth.login.extract_remote_ip') - def test_login_with_mfa_backup_code_success(self, mock_extract_ip, mock_reset_limit, + def test_login_with_mfa_backup_code_success(self, mock_extract_ip, mock_reset_limit, mock_login_service, mock_get_tenants, mock_auth_mfa, mock_is_mfa_required, mock_authenticate, - mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_rate_limit, mock_freeze_check, mock_dify_config, mock_system_features, test_client, setup_account): """Test successful login with valid backup code.""" # Setup mocks @@ -212,30 +206,30 @@ class TestLoginMFAIntegration: mock_auth_mfa.return_value = True # Valid backup code mock_get_tenants.return_value = [Mock()] # At least one tenant mock_extract_ip.return_value = "127.0.0.1" - + token_pair_mock = Mock() token_pair_mock.model_dump.return_value = { "access_token": "test_access_token", "refresh_token": "test_refresh_token" } mock_login_service.return_value = token_pair_mock - + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: mock_setup.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f - + response = test_client.post('/console/api/login', json={ "email": "test@example.com", "password": "TestPassword123", "mfa_code": "BACKUP123" # Backup code format }) - + assert response.status_code == 200 data = json.loads(response.data) assert data["result"] == "success" assert "access_token" in data["data"] - + # Verify MFA authentication was called with backup code mock_auth_mfa.assert_called_once_with(setup_account, "BACKUP123") @@ -246,40 +240,40 @@ class TestLoginMFAIntegration: @patch('controllers.console.auth.login.AccountService.authenticate') @patch('controllers.console.auth.login.MFAService.is_mfa_required') def test_login_mfa_flow_order(self, mock_is_mfa_required, mock_authenticate, - mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_rate_limit, mock_freeze_check, mock_dify_config, mock_system_features, test_client): """Test that MFA check happens after password authentication.""" # Setup mocks - password auth fails mock_dify_config.BILLING_ENABLED = False mock_freeze_check.return_value = False mock_rate_limit.return_value = False - + # Mock password authentication failure from services.errors.account import AccountPasswordError mock_authenticate.side_effect = AccountPasswordError() - + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled, \ patch('controllers.console.auth.login.AccountService.add_login_error_rate_limit') as mock_add_limit: mock_setup.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f - + response = test_client.post('/console/api/login', json={ "email": "test@example.com", "password": "WrongPassword123", "mfa_code": "123456" }) - + # Password error should trigger EmailOrPasswordMismatchError assert response.status_code == 400 - + # MFA check should not be called if password auth fails mock_is_mfa_required.assert_not_called() class TestMFAEndToEndFlow: """End-to-end tests for complete MFA flow.""" - + def setup_method(self): self.app = Flask(__name__) self.app.config['TESTING'] = True @@ -292,40 +286,40 @@ class TestMFAEndToEndFlow: @patch('services.mfa_service.db.session') def test_complete_mfa_setup_flow(self, mock_session, mock_gen_codes, mock_verify, mock_gen_qr, mock_gen_secret): """Test complete MFA setup flow from init to completion.""" - from services.mfa_service import MFAService from models.account import Account - + from services.mfa_service import MFAService + # Mock account account = Mock(spec=Account) account.id = "test-id" account.email = "test@example.com" - + # Setup mocks mock_gen_secret.return_value = "TESTSECRET123" mock_gen_qr.return_value = "data:image/png;base64,test" mock_verify.return_value = True mock_gen_codes.return_value = ["CODE1", "CODE2", "CODE3"] - + # Step 1: Initialize MFA setup with patch('services.mfa_service.MFAService.get_or_create_mfa_settings') as mock_get_settings: mfa_settings = Mock() mfa_settings.enabled = False mfa_settings.secret = None mock_get_settings.return_value = mfa_settings - + setup_data = MFAService.generate_mfa_setup_data(account) - + assert setup_data["secret"] == "TESTSECRET123" assert setup_data["qr_code"] == "data:image/png;base64,test" assert mfa_settings.secret == "TESTSECRET123" - + # Step 2: Complete MFA setup with patch('services.mfa_service.MFAService.get_or_create_mfa_settings') as mock_get_settings: mfa_settings.secret = "TESTSECRET123" mock_get_settings.return_value = mfa_settings - + result = MFAService.setup_mfa(account, "123456") - + assert mfa_settings.enabled is True assert result["backup_codes"] == ["CODE1", "CODE2", "CODE3"] - assert mfa_settings.setup_at is not None \ No newline at end of file + assert mfa_settings.setup_at is not None diff --git a/api/tests/integration_tests/controllers/console/auth/test_mfa_endpoints.py b/api/tests/integration_tests/controllers/console/auth/test_mfa_endpoints.py index 1fe32d135a..b5ab8fe255 100644 --- a/api/tests/integration_tests/controllers/console/auth/test_mfa_endpoints.py +++ b/api/tests/integration_tests/controllers/console/auth/test_mfa_endpoints.py @@ -1,7 +1,8 @@ -import pytest -from datetime import datetime, timezone +from datetime import UTC, datetime from unittest.mock import patch +import pytest + from services.account_service import AccountService from services.mfa_service import MFAService @@ -71,7 +72,7 @@ class TestMFAEndpoints: with patch.object(MFAService, 'setup_mfa') as mock_setup: mock_setup.return_value = { "backup_codes": ["CODE1", "CODE2", "CODE3", "CODE4", "CODE5", "CODE6", "CODE7", "CODE8"], - "setup_at": datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + "setup_at": datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) } response = test_client.post( @@ -97,7 +98,8 @@ class TestMFAEndpoints: assert response.status_code == 400 data = response.json - assert "message" in data and "TOTP token is required" in data["message"] + assert "message" in data + assert "TOTP token is required" in data["message"] def test_mfa_setup_complete_invalid_token(self, test_client, setup_account, auth_header): """Test MFA setup completion with invalid token.""" diff --git a/api/tests/integration_tests/controllers/console/auth/test_mfa_simple.py b/api/tests/integration_tests/controllers/console/auth/test_mfa_simple.py index 327cf908b0..a49d4c49cf 100644 --- a/api/tests/integration_tests/controllers/console/auth/test_mfa_simple.py +++ b/api/tests/integration_tests/controllers/console/auth/test_mfa_simple.py @@ -1,7 +1,5 @@ -import json from unittest import mock -from models.account import Account from services.mfa_service import MFAService @@ -12,7 +10,7 @@ class TestMFASimpleIntegration: """Test MFA setup flow end-to-end.""" # Step 1: Check initial MFA status response = test_client.get( - f"/console/api/account/mfa/status", + "/console/api/account/mfa/status", headers=auth_header ) assert response.status_code == 200 @@ -21,7 +19,7 @@ class TestMFASimpleIntegration: # Step 2: Initialize MFA setup response = test_client.post( - f"/console/api/account/mfa/setup", + "/console/api/account/mfa/setup", headers=auth_header ) assert response.status_code == 200 @@ -33,7 +31,7 @@ class TestMFASimpleIntegration: # Step 3: Complete MFA setup with mocked TOTP with mock.patch.object(MFAService, 'verify_totp', return_value=True): response = test_client.post( - f"/console/api/account/mfa/setup/complete", + "/console/api/account/mfa/setup/complete", headers=auth_header, json={"totp_token": "123456"} ) @@ -44,7 +42,7 @@ class TestMFASimpleIntegration: # Step 4: Verify MFA is now enabled response = test_client.get( - f"/console/api/account/mfa/status", + "/console/api/account/mfa/status", headers=auth_header ) assert response.status_code == 200 @@ -55,7 +53,7 @@ class TestMFASimpleIntegration: """Test MFA disable flow.""" # First check MFA status and disable if already enabled response = test_client.get( - f"/console/api/account/mfa/status", + "/console/api/account/mfa/status", headers=auth_header ) assert response.status_code == 200 @@ -65,7 +63,7 @@ class TestMFASimpleIntegration: # MFA is already enabled, disable it first with mocked password verification with mock.patch('libs.password.compare_password', return_value=True): response = test_client.post( - f"/console/api/account/mfa/disable", + "/console/api/account/mfa/disable", headers=auth_header, json={"password": "any_password"} # Password doesn't matter, it's mocked ) @@ -75,14 +73,14 @@ class TestMFASimpleIntegration: with mock.patch.object(MFAService, 'verify_totp', return_value=True): # Initialize setup response = test_client.post( - f"/console/api/account/mfa/setup", + "/console/api/account/mfa/setup", headers=auth_header ) assert response.status_code == 200 # Complete setup response = test_client.post( - f"/console/api/account/mfa/setup/complete", + "/console/api/account/mfa/setup/complete", headers=auth_header, json={"totp_token": "123456"} ) @@ -91,7 +89,7 @@ class TestMFASimpleIntegration: # Now disable MFA with mocked password verification with mock.patch('libs.password.compare_password', return_value=True): response = test_client.post( - f"/console/api/account/mfa/disable", + "/console/api/account/mfa/disable", headers=auth_header, json={"password": "any_password"} # Password doesn't matter, it's mocked ) @@ -101,7 +99,7 @@ class TestMFASimpleIntegration: # Verify MFA is disabled response = test_client.get( - f"/console/api/account/mfa/status", + "/console/api/account/mfa/status", headers=auth_header ) assert response.status_code == 200 diff --git a/api/tests/unit_tests/services/test_mfa_service.py b/api/tests/unit_tests/services/test_mfa_service.py index a820b8c93e..7d77462584 100644 --- a/api/tests/unit_tests/services/test_mfa_service.py +++ b/api/tests/unit_tests/services/test_mfa_service.py @@ -1,9 +1,7 @@ import json import unittest +from datetime import datetime from unittest.mock import Mock, patch -from datetime import datetime, timezone - -import pytest from models.account import Account, AccountMFASettings from services.mfa_service import MFAService