diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index b40934dbf5..c222967b67 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -65,3 +65,27 @@ class EmailPasswordResetLimitError(BaseHTTPException): error_code = "email_password_reset_limit" description = "Too many failed password reset attempts. Please try again in 24 hours." code = 429 + + +class MFARequiredError(BaseHTTPException): + error_code = "mfa_required" + description = "Multi-factor authentication is required." + code = 401 + + +class MFATokenRequiredError(BaseHTTPException): + error_code = "mfa_token_invalid" + description = "The MFA token is invalid or expired." + code = 401 + + +class MFASetupRequiredError(BaseHTTPException): + error_code = "mfa_setup_required" + description = "MFA setup is required to complete this action." + code = 400 + + +class TokenValidationError(BaseHTTPException): + error_code = "token_validation_error" + description = "Token validation failed." + code = 400 diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 5f2a24322d..aec21c1bb3 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -14,6 +14,8 @@ from controllers.console.auth.error import ( EmailPasswordLoginLimitError, InvalidEmailError, InvalidTokenError, + MFARequiredError, + MFATokenRequiredError, ) from controllers.console.error import ( AccountBannedError, @@ -33,6 +35,7 @@ from services.billing_service import BillingService from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService +from services.mfa_service import MFAService class LoginApi(Resource): @@ -48,6 +51,8 @@ class LoginApi(Resource): parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") parser.add_argument("invite_token", type=str, required=False, default=None, location="json") parser.add_argument("language", type=str, required=False, default="en-US", location="json") + parser.add_argument("mfa_code", type=str, required=False, default=None, location="json") + parser.add_argument("is_backup_code", type=bool, required=False, default=False, location="json") args = parser.parse_args() if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): @@ -86,6 +91,15 @@ class LoginApi(Resource): return {"result": "fail", "data": token, "code": "account_not_found"} else: raise AccountNotFound() + + # Check MFA requirement + if MFAService.is_mfa_required(account): + if not args["mfa_code"]: + return {"result": "fail", "code": "mfa_required"} + + if not MFAService.authenticate_with_mfa(account, args["mfa_code"]): + return {"result": "fail", "code": "mfa_token_invalid", "data": "The MFA token is invalid or expired."} + # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: diff --git a/api/controllers/console/auth/mfa.py b/api/controllers/console/auth/mfa.py new file mode 100644 index 0000000000..4c3f76eac1 --- /dev/null +++ b/api/controllers/console/auth/mfa.py @@ -0,0 +1,128 @@ +from typing import cast + +import flask_login +from flask import request +from flask_restful import Resource, reqparse + +from controllers.console.auth.error import ( + TokenValidationError, +) +from controllers.console.wraps import account_initialization_required +from libs.login import login_required +from models.account import Account +from services.mfa_service import MFAService + + +class MFASetupInitApi(Resource): + @login_required + @account_initialization_required + def get(self): + """Initialize MFA setup - generate secret and QR code (GET method for compatibility).""" + return self.post() + + @login_required + @account_initialization_required + def post(self): + """Initialize MFA setup - generate secret and QR code.""" + account = cast(Account, flask_login.current_user) + + try: + mfa_status = MFAService.get_mfa_status(account) + if mfa_status["enabled"]: + return {"error": "MFA is already enabled"}, 400 + + setup_data = MFAService.generate_mfa_setup_data(account) + return { + "secret": setup_data["secret"], + "qr_code": setup_data["qr_code"] + } + except Exception as e: + return {"error": str(e)}, 500 + + +class MFASetupCompleteApi(Resource): + @login_required + @account_initialization_required + def post(self): + """Complete MFA setup with TOTP verification.""" + parser = reqparse.RequestParser() + parser.add_argument("totp_token", type=str, required=True, help="TOTP token is required") + args = parser.parse_args() + + account = cast(Account, flask_login.current_user) + + try: + result = MFAService.setup_mfa(account, args["totp_token"]) + return { + "message": "MFA setup completed successfully", + "backup_codes": result["backup_codes"], + "setup_at": result["setup_at"].isoformat() + } + except ValueError as e: + return {"error": str(e)}, 400 + except Exception as e: + return {"error": str(e)}, 500 + + +class MFADisableApi(Resource): + @login_required + @account_initialization_required + def post(self): + """Disable MFA with password verification.""" + parser = reqparse.RequestParser() + parser.add_argument("password", type=str, required=True, help="Password is required") + args = parser.parse_args() + + account = cast(Account, flask_login.current_user) + + try: + mfa_status = MFAService.get_mfa_status(account) + if not mfa_status["enabled"]: + return {"error": "MFA is not enabled"}, 400 + + if MFAService.disable_mfa(account, args["password"]): + return {"message": "MFA disabled successfully"} + else: + return {"error": "Invalid password"}, 400 + except Exception as e: + return {"error": str(e)}, 500 + + +class MFAStatusApi(Resource): + @login_required + @account_initialization_required + def get(self): + """Get current MFA status.""" + account = cast(Account, flask_login.current_user) + + try: + status = MFAService.get_mfa_status(account) + return status + except Exception as e: + return {"error": str(e)}, 500 + + +class MFAVerifyApi(Resource): + def post(self): + """Verify MFA token during login (public endpoint).""" + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, help="Email is required") + parser.add_argument("mfa_token", type=str, required=True, help="MFA token is required") + args = parser.parse_args() + + from models.engine import db + account = db.session.query(Account).filter_by(email=args["email"]).first() + + if not account: + return {"error": "Account not found"}, 404 + + if not MFAService.is_mfa_required(account): + return {"error": "MFA not required for this account"}, 400 + + try: + if MFAService.authenticate_with_mfa(account, args["mfa_token"]): + return {"message": "MFA verification successful"} + else: + return {"error": "Invalid MFA token"}, 400 + except Exception as e: + return {"error": str(e)}, 500 \ No newline at end of file diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index a9dbf44456..9c79782f47 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -387,3 +387,10 @@ api.add_resource(EducationApi, "/account/education") api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') + +# MFA endpoints +from controllers.console.auth.mfa import MFASetupInitApi, MFASetupCompleteApi, MFADisableApi, MFAStatusApi +api.add_resource(MFAStatusApi, "/account/mfa/status") +api.add_resource(MFASetupInitApi, "/account/mfa/setup") +api.add_resource(MFASetupCompleteApi, "/account/mfa/setup/complete") +api.add_resource(MFADisableApi, "/account/mfa/disable") diff --git a/api/migrations/versions/2025_07_13_0900-xyz789abc123_add_account_mfa_settings.py b/api/migrations/versions/2025_07_13_0900-xyz789abc123_add_account_mfa_settings.py new file mode 100644 index 0000000000..06d72690a4 --- /dev/null +++ b/api/migrations/versions/2025_07_13_0900-xyz789abc123_add_account_mfa_settings.py @@ -0,0 +1,43 @@ +"""add account mfa settings table + +Revision ID: xyz789abc123 +Revises: 58eb7bdb93fe +Create Date: 2025-07-13 09:00:00.000000 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'xyz789abc123' +down_revision = '58eb7bdb93fe' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('account_mfa_settings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('secret', sa.String(length=255), nullable=True), + sa.Column('backup_codes', sa.Text(), nullable=True), + sa.Column('setup_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ), + sa.PrimaryKeyConstraint('id', name='account_mfa_settings_pkey'), + sa.UniqueConstraint('account_id', name='unique_account_mfa_settings') + ) + op.create_index('account_mfa_settings_account_id_idx', 'account_mfa_settings', ['account_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('account_mfa_settings_account_id_idx', table_name='account_mfa_settings') + op.drop_table('account_mfa_settings') + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/models/account.py b/api/models/account.py index 7ffeefa980..c8c4abe175 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -299,3 +299,24 @@ class TenantPluginPermission(Base): db.String(16), nullable=False, server_default="everyone" ) debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") + + +class AccountMFASettings(Base): + __tablename__ = "account_mfa_settings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="account_mfa_settings_pkey"), + db.UniqueConstraint("account_id", name="unique_account_mfa_settings"), + db.Index("account_mfa_settings_account_id_idx", "account_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + account_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("accounts.id"), nullable=False) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + secret = db.Column(db.String(255), nullable=True) + backup_codes = db.Column(db.Text, nullable=True) + setup_at = db.Column(db.DateTime, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + # Relationship + account = db.relationship("Account", backref=db.backref("mfa_settings", uselist=False, cascade="all, delete-orphan")) diff --git a/api/pyproject.toml b/api/pyproject.toml index 7f1efa671f..3fcbcb9862 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -66,10 +66,12 @@ dependencies = [ "pydantic-extra-types~=2.10.3", "pydantic-settings~=2.9.1", "pyjwt~=2.8.0", + "pyotp~=2.9.0", "pypdfium2==4.30.0", "python-docx~=1.1.0", "python-dotenv==1.0.1", "pyyaml~=6.0.1", + "qrcode[pil]~=7.4.2", "readabilipy~=0.3.0", "redis[hiredis]~=6.1.0", "resend~=2.9.0", diff --git a/api/services/mfa_service.py b/api/services/mfa_service.py new file mode 100644 index 0000000000..250c9db9be --- /dev/null +++ b/api/services/mfa_service.py @@ -0,0 +1,224 @@ +import base64 +import io +import json +import secrets +from datetime import datetime, timezone +from typing import Optional + +import pyotp +import qrcode +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from models.account import Account, AccountMFASettings +from models.engine import db + + +class MFAService: + @staticmethod + def generate_secret() -> str: + """Generate a new TOTP secret for the user.""" + return pyotp.random_base32() + + @staticmethod + def generate_backup_codes(count: int = 8) -> list[str]: + """Generate backup codes for account recovery.""" + codes = [] + for _ in range(count): + code = secrets.token_hex(4).upper() + codes.append(code) + return codes + + @staticmethod + def generate_qr_code(account: Account, secret: str) -> str: + """Generate QR code for TOTP setup.""" + totp = pyotp.TOTP(secret) + provisioning_uri = totp.provisioning_uri( + name=account.email, + issuer_name="Dify" + ) + + # Generate QR code + qr = qrcode.QRCode( + version=1, + error_correction=qrcode.constants.ERROR_CORRECT_L, + box_size=10, + border=4, + ) + qr.add_data(provisioning_uri) + qr.make(fit=True) + + # Create image + img = qr.make_image(fill_color="black", back_color="white") + + # Convert to base64 + buffer = io.BytesIO() + img.save(buffer, format='PNG') + img_str = base64.b64encode(buffer.getvalue()).decode() + + return f"data:image/png;base64,{img_str}" + + @staticmethod + def verify_totp(secret: str, token: str) -> bool: + """Verify TOTP token.""" + if not secret: + return False + try: + totp = pyotp.TOTP(secret) + return totp.verify(token, valid_window=1) + except Exception as e: + print(f"[MFA DEBUG] verify_totp error: {type(e).__name__}: {str(e)}") + return False + + @staticmethod + def get_or_create_mfa_settings(account: Account) -> AccountMFASettings: + """Get or create MFA settings for account.""" + mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() + if not mfa_settings: + mfa_settings = AccountMFASettings(account_id=account.id) + db.session.add(mfa_settings) + db.session.commit() + return mfa_settings + + @staticmethod + def verify_backup_code(mfa_settings: AccountMFASettings, code: str) -> bool: + """Verify and consume backup code.""" + if not mfa_settings.backup_codes: + return False + + try: + backup_codes = json.loads(mfa_settings.backup_codes) + if code.upper() in backup_codes: + # Remove used backup code + backup_codes.remove(code.upper()) + mfa_settings.backup_codes = json.dumps(backup_codes) + db.session.commit() + return True + except json.JSONDecodeError: + pass + + return False + + @staticmethod + def setup_mfa(account: Account, totp_token: str) -> dict: + """Setup MFA for account with TOTP verification.""" + mfa_settings = MFAService.get_or_create_mfa_settings(account) + + if mfa_settings.enabled: + raise ValueError("MFA is already enabled for this account") + + if not mfa_settings.secret: + raise ValueError("MFA secret not generated") + + # Verify TOTP token + if not MFAService.verify_totp(mfa_settings.secret, totp_token): + raise ValueError("Invalid TOTP token") + + # Generate backup codes + backup_codes = MFAService.generate_backup_codes() + + # Enable MFA + mfa_settings.enabled = True + mfa_settings.backup_codes = json.dumps(backup_codes) + mfa_settings.setup_at = datetime.now(timezone.utc) + + db.session.commit() + + return { + "backup_codes": backup_codes, + "setup_at": mfa_settings.setup_at + } + + @staticmethod + def disable_mfa(account: Account, password: str) -> bool: + """Disable MFA for account after password verification.""" + from libs.password import compare_password + + # Verify password + if account.password is None or not compare_password(password, account.password, account.password_salt): + return False + + mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() + if not mfa_settings: + return True # Already disabled + + # Disable MFA + mfa_settings.enabled = False + mfa_settings.secret = None + mfa_settings.backup_codes = None + mfa_settings.setup_at = None + + db.session.commit() + return True + + @staticmethod + def generate_mfa_setup_data(account: Account) -> dict: + """Generate MFA setup data including secret and QR code.""" + mfa_settings = MFAService.get_or_create_mfa_settings(account) + + if mfa_settings.enabled: + raise ValueError("MFA is already enabled for this account") + + # Generate new secret + secret = MFAService.generate_secret() + mfa_settings.secret = secret + db.session.commit() + + # Generate QR code + qr_code = MFAService.generate_qr_code(account, secret) + + return { + "secret": secret, + "qr_code": qr_code + } + + @staticmethod + def is_mfa_required(account: Account) -> bool: + """Check if MFA is required for this account.""" + mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() + return mfa_settings and mfa_settings.enabled and mfa_settings.secret is not None + + @staticmethod + def authenticate_with_mfa(account: Account, token: str) -> bool: + """Authenticate user with MFA token (TOTP or backup code).""" + print(f"[MFA DEBUG] authenticate_with_mfa called with token: {token}") + mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() + + if not mfa_settings or not mfa_settings.enabled: + print("[MFA DEBUG] MFA not enabled, returning True") + return True + + print(f"[MFA DEBUG] MFA enabled, secret: {mfa_settings.secret[:10]}...") + + # Try TOTP first + print("[MFA DEBUG] Trying TOTP verification") + if MFAService.verify_totp(mfa_settings.secret, token): + print("[MFA DEBUG] TOTP verification successful") + return True + + # Try backup code + print("[MFA DEBUG] Trying backup code verification") + if MFAService.verify_backup_code(mfa_settings, token): + print("[MFA DEBUG] Backup code verification successful") + return True + + print("[MFA DEBUG] All verifications failed") + return False + + @staticmethod + def get_mfa_status(account: Account) -> dict: + """Get MFA status for account.""" + mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() + + if not mfa_settings: + return { + "enabled": False, + "setup_at": None, + "has_backup_codes": False + } + + return { + "enabled": mfa_settings.enabled, + "setup_at": mfa_settings.setup_at.isoformat() if mfa_settings.setup_at else None, + "has_backup_codes": mfa_settings.backup_codes is not None + } \ No newline at end of file diff --git a/api/tests/integration_tests/controllers/__init__.py b/api/tests/integration_tests/controllers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/controllers/console/auth/__init__.py b/api/tests/integration_tests/controllers/console/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/controllers/console/auth/test_login_mfa_integration.py b/api/tests/integration_tests/controllers/console/auth/test_login_mfa_integration.py new file mode 100644 index 0000000000..e13bd51592 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/auth/test_login_mfa_integration.py @@ -0,0 +1,296 @@ +import json +import pytest +from unittest.mock import Mock, patch +from datetime import datetime + +from flask import Flask +from flask_restful import Api + +from controllers.console.auth.login import LoginApi +from models.account import Account, AccountMFASettings + + +class TestLoginMFAIntegration: + + def test_login_without_mfa_success(self, test_client, setup_account): + """Test successful login without MFA enabled.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + mock_auth.return_value = setup_account + + with patch('services.mfa_service.MFAService.is_mfa_required') as mock_mfa: + mock_mfa.return_value = False + + response = test_client.post('/console/api/login', json={ + "email": setup_account.email, + "password": "test_password" + }) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "success" + assert "access_token" in data["data"] + + @patch('controllers.console.auth.login.FeatureService.get_system_features') + @patch('controllers.console.auth.login.dify_config') + @patch('controllers.console.auth.login.BillingService.is_email_in_freeze') + @patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') + @patch('controllers.console.auth.login.AccountService.authenticate') + @patch('controllers.console.auth.login.MFAService.is_mfa_required') + def test_login_with_mfa_required_no_token(self, mock_is_mfa_required, mock_authenticate, + mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_system_features, test_client, setup_account): + """Test login returns mfa_required when MFA is enabled but no token provided.""" + # Setup mocks + mock_dify_config.BILLING_ENABLED = False + mock_freeze_check.return_value = False + mock_rate_limit.return_value = False + mock_authenticate.return_value = setup_account + mock_is_mfa_required.return_value = True + + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ + patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: + mock_setup.return_value = lambda f: f + mock_email_enabled.return_value = lambda f: f + + response = test_client.post('/console/api/login', json={ + "email": "test@example.com", + "password": "test_password" + }) + + assert response.status_code == 200 + data = json.loads(response.data) + assert data["result"] == "fail" + assert data["code"] == "mfa_required" + + @patch('controllers.console.auth.login.FeatureService.get_system_features') + @patch('controllers.console.auth.login.dify_config') + @patch('controllers.console.auth.login.BillingService.is_email_in_freeze') + @patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') + @patch('controllers.console.auth.login.AccountService.authenticate') + @patch('controllers.console.auth.login.MFAService.is_mfa_required') + @patch('controllers.console.auth.login.MFAService.authenticate_with_mfa') + def test_login_with_mfa_invalid_token(self, mock_auth_mfa, mock_is_mfa_required, mock_authenticate, + mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_system_features, test_client, setup_account): + """Test login fails with invalid MFA token.""" + # Setup mocks + mock_dify_config.BILLING_ENABLED = False + mock_freeze_check.return_value = False + mock_rate_limit.return_value = False + mock_authenticate.return_value = setup_account + mock_is_mfa_required.return_value = True + mock_auth_mfa.return_value = False # Invalid token + + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ + patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: + mock_setup.return_value = lambda f: f + mock_email_enabled.return_value = lambda f: f + + response = test_client.post('/console/api/login', json={ + "email": "test@example.com", + "password": "test_password", + "mfa_code": "invalid_token" + }) + + assert response.status_code == 200 + data = json.loads(response.data) + assert data["result"] == "fail" + assert data["code"] == "mfa_token_invalid" + assert data["data"] == "The MFA token is invalid or expired." + + @patch('controllers.console.auth.login.FeatureService.get_system_features') + @patch('controllers.console.auth.login.dify_config') + @patch('controllers.console.auth.login.BillingService.is_email_in_freeze') + @patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') + @patch('controllers.console.auth.login.AccountService.authenticate') + @patch('controllers.console.auth.login.MFAService.is_mfa_required') + @patch('controllers.console.auth.login.MFAService.authenticate_with_mfa') + @patch('controllers.console.auth.login.TenantService.get_join_tenants') + @patch('controllers.console.auth.login.AccountService.login') + @patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') + @patch('controllers.console.auth.login.extract_remote_ip') + def test_login_with_mfa_valid_token_success(self, mock_extract_ip, mock_reset_limit, + mock_login_service, mock_get_tenants, mock_auth_mfa, + mock_is_mfa_required, mock_authenticate, + mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_system_features, test_client, setup_account): + """Test successful login with valid MFA token.""" + # Setup mocks + mock_dify_config.BILLING_ENABLED = False + mock_freeze_check.return_value = False + mock_rate_limit.return_value = False + mock_authenticate.return_value = setup_account + mock_is_mfa_required.return_value = True + mock_auth_mfa.return_value = True # Valid token + mock_get_tenants.return_value = [Mock()] # At least one tenant + mock_extract_ip.return_value = "127.0.0.1" + + token_pair_mock = Mock() + token_pair_mock.model_dump.return_value = { + "access_token": "test_access_token", + "refresh_token": "test_refresh_token" + } + mock_login_service.return_value = token_pair_mock + + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ + patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: + mock_setup.return_value = lambda f: f + mock_email_enabled.return_value = lambda f: f + + response = test_client.post('/console/api/login', json={ + "email": "test@example.com", + "password": "test_password", + "mfa_code": "123456" + }) + + assert response.status_code == 200 + data = json.loads(response.data) + assert data["result"] == "success" + assert "access_token" in data["data"] + + # Verify MFA authentication was called + mock_auth_mfa.assert_called_once_with(setup_account, "123456") + + @patch('controllers.console.auth.login.FeatureService.get_system_features') + @patch('controllers.console.auth.login.dify_config') + @patch('controllers.console.auth.login.BillingService.is_email_in_freeze') + @patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') + @patch('controllers.console.auth.login.AccountService.authenticate') + @patch('controllers.console.auth.login.MFAService.is_mfa_required') + @patch('controllers.console.auth.login.MFAService.authenticate_with_mfa') + @patch('controllers.console.auth.login.TenantService.get_join_tenants') + @patch('controllers.console.auth.login.AccountService.login') + @patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') + @patch('controllers.console.auth.login.extract_remote_ip') + def test_login_with_mfa_backup_code_success(self, mock_extract_ip, mock_reset_limit, + mock_login_service, mock_get_tenants, mock_auth_mfa, + mock_is_mfa_required, mock_authenticate, + mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_system_features, test_client, setup_account): + """Test successful login with valid backup code.""" + # Setup mocks + mock_dify_config.BILLING_ENABLED = False + mock_freeze_check.return_value = False + mock_rate_limit.return_value = False + mock_authenticate.return_value = setup_account + mock_is_mfa_required.return_value = True + mock_auth_mfa.return_value = True # Valid backup code + mock_get_tenants.return_value = [Mock()] # At least one tenant + mock_extract_ip.return_value = "127.0.0.1" + + token_pair_mock = Mock() + token_pair_mock.model_dump.return_value = { + "access_token": "test_access_token", + "refresh_token": "test_refresh_token" + } + mock_login_service.return_value = token_pair_mock + + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ + patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: + mock_setup.return_value = lambda f: f + mock_email_enabled.return_value = lambda f: f + + response = test_client.post('/console/api/login', json={ + "email": "test@example.com", + "password": "test_password", + "mfa_code": "BACKUP123" # Backup code format + }) + + assert response.status_code == 200 + data = json.loads(response.data) + assert data["result"] == "success" + assert "access_token" in data["data"] + + # Verify MFA authentication was called with backup code + mock_auth_mfa.assert_called_once_with(setup_account, "BACKUP123") + + @patch('controllers.console.auth.login.FeatureService.get_system_features') + @patch('controllers.console.auth.login.dify_config') + @patch('controllers.console.auth.login.BillingService.is_email_in_freeze') + @patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') + @patch('controllers.console.auth.login.AccountService.authenticate') + @patch('controllers.console.auth.login.MFAService.is_mfa_required') + def test_login_mfa_flow_order(self, mock_is_mfa_required, mock_authenticate, + mock_rate_limit, mock_freeze_check, mock_dify_config, + mock_system_features, test_client): + """Test that MFA check happens after password authentication.""" + # Setup mocks - password auth fails + mock_dify_config.BILLING_ENABLED = False + mock_freeze_check.return_value = False + mock_rate_limit.return_value = False + + # Mock password authentication failure + from services.errors.account import AccountPasswordError + mock_authenticate.side_effect = AccountPasswordError() + + with patch('controllers.console.auth.login.setup_required') as mock_setup, \ + patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled, \ + patch('controllers.console.auth.login.AccountService.add_login_error_rate_limit') as mock_add_limit: + mock_setup.return_value = lambda f: f + mock_email_enabled.return_value = lambda f: f + + response = test_client.post('/console/api/login', json={ + "email": "test@example.com", + "password": "wrong_password", + "mfa_code": "123456" + }) + + # Password error should trigger EmailOrPasswordMismatchError + assert response.status_code == 400 + + # MFA check should not be called if password auth fails + mock_is_mfa_required.assert_not_called() + + +class TestMFAEndToEndFlow: + """End-to-end tests for complete MFA flow.""" + + def setup_method(self): + self.app = Flask(__name__) + self.app.config['TESTING'] = True + self.client = self.app.test_client() + + @patch('services.mfa_service.MFAService.generate_secret') + @patch('services.mfa_service.MFAService.generate_qr_code') + @patch('services.mfa_service.MFAService.verify_totp') + @patch('services.mfa_service.MFAService.generate_backup_codes') + @patch('services.mfa_service.db.session') + def test_complete_mfa_setup_flow(self, mock_session, mock_gen_codes, mock_verify, mock_gen_qr, mock_gen_secret): + """Test complete MFA setup flow from init to completion.""" + from services.mfa_service import MFAService + from models.account import Account + + # Mock account + account = Mock(spec=Account) + account.id = "test-id" + account.email = "test@example.com" + + # Setup mocks + mock_gen_secret.return_value = "TESTSECRET123" + mock_gen_qr.return_value = "data:image/png;base64,test" + mock_verify.return_value = True + mock_gen_codes.return_value = ["CODE1", "CODE2", "CODE3"] + + # Step 1: Initialize MFA setup + with patch('services.mfa_service.MFAService.get_or_create_mfa_settings') as mock_get_settings: + mfa_settings = Mock() + mfa_settings.enabled = False + mfa_settings.secret = None + mock_get_settings.return_value = mfa_settings + + setup_data = MFAService.generate_mfa_setup_data(account) + + assert setup_data["secret"] == "TESTSECRET123" + assert setup_data["qr_code"] == "data:image/png;base64,test" + assert mfa_settings.secret == "TESTSECRET123" + + # Step 2: Complete MFA setup + with patch('services.mfa_service.MFAService.get_or_create_mfa_settings') as mock_get_settings: + mfa_settings.secret = "TESTSECRET123" + mock_get_settings.return_value = mfa_settings + + result = MFAService.setup_mfa(account, "123456") + + assert mfa_settings.enabled is True + assert result["backup_codes"] == ["CODE1", "CODE2", "CODE3"] + assert mfa_settings.setup_at is not None \ No newline at end of file diff --git a/api/tests/integration_tests/controllers/console/auth/test_mfa_endpoints.py b/api/tests/integration_tests/controllers/console/auth/test_mfa_endpoints.py new file mode 100644 index 0000000000..97fa789af0 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/auth/test_mfa_endpoints.py @@ -0,0 +1,266 @@ +import pytest +from unittest.mock import patch + +from services.account_service import AccountService +from services.mfa_service import MFAService + + +class TestMFAEndpoints: + """Test MFA endpoints using integration test approach.""" + + @pytest.fixture + def auth_header(self, setup_account): + """Get authentication header with JWT token.""" + token = AccountService.get_account_jwt_token(setup_account) + return {"Authorization": f"Bearer {token}"} + + def test_mfa_status_success(self, test_client, setup_account, auth_header): + """Test successful MFA status check.""" + with patch.object(MFAService, 'get_mfa_status') as mock_status: + mock_status.return_value = {"enabled": False, "setup_at": None} + + response = test_client.get( + '/console/api/account/mfa/status', + headers=auth_header + ) + + assert response.status_code == 200 + data = response.json + assert data["enabled"] is False + assert data["setup_at"] is None + mock_status.assert_called_once_with(setup_account) + + def test_mfa_setup_init_success(self, test_client, setup_account, auth_header): + """Test successful MFA setup initialization.""" + with patch.object(MFAService, 'get_mfa_status') as mock_status: + with patch.object(MFAService, 'generate_mfa_setup_data') as mock_generate: + mock_status.return_value = {"enabled": False} + mock_generate.return_value = { + "secret": "TEST_SECRET", + "qr_code": "data:image/png;base64,test" + } + + response = test_client.post( + '/console/api/account/mfa/setup', + headers=auth_header + ) + + assert response.status_code == 200 + data = response.json + assert data["secret"] == "TEST_SECRET" + assert data["qr_code"] == "data:image/png;base64,test" + mock_generate.assert_called_once_with(setup_account) + + def test_mfa_setup_init_already_enabled(self, test_client, setup_account, auth_header): + """Test MFA setup initialization when already enabled.""" + with patch.object(MFAService, 'get_mfa_status') as mock_status: + mock_status.return_value = {"enabled": True, "setup_at": "2024-01-01T00:00:00"} + + response = test_client.post( + '/console/api/account/mfa/setup', + headers=auth_header + ) + + assert response.status_code == 400 + data = response.json + assert data["error"] == "MFA is already enabled" + + def test_mfa_setup_complete_success(self, test_client, setup_account, auth_header): + """Test successful MFA setup completion.""" + with patch.object(MFAService, 'setup_mfa') as mock_setup: + mock_setup.return_value = { + "message": "MFA has been successfully enabled", + "backup_codes": ["CODE1", "CODE2", "CODE3", "CODE4", "CODE5", "CODE6", "CODE7", "CODE8"], + "setup_at": "2024-01-01T00:00:00" + } + + response = test_client.post( + '/console/api/account/mfa/setup/complete', + headers=auth_header, + json={"totp_token": "123456"} + ) + + assert response.status_code == 200 + data = response.json + assert data["message"] == "MFA has been successfully enabled" + assert len(data["backup_codes"]) == 8 + mock_setup.assert_called_once_with(setup_account, "123456") + + def test_mfa_setup_complete_missing_token(self, test_client, setup_account, auth_header): + """Test MFA setup completion with missing token.""" + response = test_client.post( + '/console/api/account/mfa/setup/complete', + headers=auth_header, + json={} + ) + + assert response.status_code == 400 + data = response.json + assert "totp_token is required" in data["error"] + + def test_mfa_setup_complete_invalid_token(self, test_client, setup_account, auth_header): + """Test MFA setup completion with invalid token.""" + with patch.object(MFAService, 'setup_mfa') as mock_setup: + mock_setup.side_effect = ValueError("Invalid TOTP token") + + response = test_client.post( + '/console/api/account/mfa/setup/complete', + headers=auth_header, + json={"totp_token": "999999"} + ) + + assert response.status_code == 400 + data = response.json + assert "Invalid TOTP token" in data["error"] + + def test_mfa_disable_success(self, test_client, setup_account, auth_header): + """Test successful MFA disable.""" + with patch.object(MFAService, 'disable_mfa') as mock_disable: + mock_disable.return_value = {"message": "MFA has been disabled"} + + response = test_client.post( + '/console/api/account/mfa/disable', + headers=auth_header, + json={"password": "test_password"} + ) + + assert response.status_code == 200 + data = response.json + assert data["message"] == "MFA has been disabled" + mock_disable.assert_called_once_with(setup_account, "test_password") + + def test_mfa_disable_wrong_password(self, test_client, setup_account, auth_header): + """Test MFA disable with wrong password.""" + with patch.object(MFAService, 'disable_mfa') as mock_disable: + mock_disable.side_effect = ValueError("Invalid password") + + response = test_client.post( + '/console/api/account/mfa/disable', + headers=auth_header, + json={"password": "wrong_password"} + ) + + assert response.status_code == 400 + data = response.json + assert "Invalid password" in data["error"] + + def test_mfa_disable_not_enabled(self, test_client, setup_account, auth_header): + """Test MFA disable when not enabled.""" + with patch.object(MFAService, 'disable_mfa') as mock_disable: + mock_disable.side_effect = ValueError("MFA is not enabled") + + response = test_client.post( + '/console/api/account/mfa/disable', + headers=auth_header, + json={"password": "test_password"} + ) + + assert response.status_code == 400 + data = response.json + assert "MFA is not enabled" in data["error"] + + def test_mfa_verify_success(self, test_client): + """Test successful MFA verification during login.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + with patch.object(MFAService, 'is_mfa_required') as mock_required: + with patch.object(MFAService, 'authenticate_with_mfa') as mock_verify: + # Mock user exists + from models.account import Account + mock_account = Account( + id="test-id", + email="test@example.com", + name="Test User" + ) + mock_auth.return_value = mock_account + mock_required.return_value = True + mock_verify.return_value = True + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "test@example.com", + "mfa_code": "123456", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "success" + + def test_mfa_verify_invalid_token(self, test_client): + """Test MFA verification with invalid token.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + with patch.object(MFAService, 'is_mfa_required') as mock_required: + with patch.object(MFAService, 'authenticate_with_mfa') as mock_verify: + # Mock user exists + from models.account import Account + mock_account = Account( + id="test-id", + email="test@example.com", + name="Test User" + ) + mock_auth.return_value = mock_account + mock_required.return_value = True + mock_verify.return_value = False + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "test@example.com", + "mfa_code": "999999", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "fail" + assert data["code"] == "mfa_token_invalid" + + def test_mfa_verify_not_required(self, test_client): + """Test MFA verification when MFA is not required.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + with patch.object(MFAService, 'is_mfa_required') as mock_required: + # Mock user exists + from models.account import Account + mock_account = Account( + id="test-id", + email="test@example.com", + name="Test User" + ) + mock_auth.return_value = mock_account + mock_required.return_value = False + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "test@example.com", + "mfa_code": "123456", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "fail" + assert data["code"] == "mfa_not_required" + + def test_mfa_verify_account_not_found(self, test_client): + """Test MFA verification with non-existent account.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + mock_auth.return_value = None + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "nonexistent@example.com", + "mfa_code": "123456", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "fail" + assert data["code"] == "mfa_verify_failed" \ No newline at end of file diff --git a/api/tests/integration_tests/controllers/console/auth/test_mfa_simple.py b/api/tests/integration_tests/controllers/console/auth/test_mfa_simple.py new file mode 100644 index 0000000000..0fb7d2e936 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/auth/test_mfa_simple.py @@ -0,0 +1,90 @@ +import json +from unittest import mock + +from models.account import Account +from services.mfa_service import MFAService + + +class TestMFASimpleIntegration: + """Simple integration tests for MFA functionality.""" + + def test_mfa_setup_flow(self, test_client, setup_account, auth_header): + """Test MFA setup flow end-to-end.""" + # Step 1: Check initial MFA status + response = test_client.get( + f"/console/api/account/mfa/status", + headers=auth_header + ) + assert response.status_code == 200 + data = response.json + assert data["enabled"] is False + + # Step 2: Initialize MFA setup + response = test_client.post( + f"/console/api/account/mfa/setup", + headers=auth_header + ) + assert response.status_code == 200 + data = response.json + assert "secret" in data + assert "qr_code" in data + secret = data["secret"] + + # Step 3: Complete MFA setup with mocked TOTP + with mock.patch.object(MFAService, 'verify_totp', return_value=True): + response = test_client.post( + f"/console/api/account/mfa/setup/complete", + headers=auth_header, + json={"totp_token": "123456"} + ) + assert response.status_code == 200 + data = response.json + assert "backup_codes" in data + assert len(data["backup_codes"]) == 8 + + # Step 4: Verify MFA is now enabled + response = test_client.get( + f"/console/api/account/mfa/status", + headers=auth_header + ) + assert response.status_code == 200 + data = response.json + assert data["enabled"] is True + + def test_mfa_disable_flow(self, test_client, setup_account, auth_header): + """Test MFA disable flow.""" + # First, set up MFA for the account + with mock.patch.object(MFAService, 'verify_totp', return_value=True): + # Initialize setup + response = test_client.post( + f"/console/api/account/mfa/setup", + headers=auth_header + ) + assert response.status_code == 200 + + # Complete setup + response = test_client.post( + f"/console/api/account/mfa/setup/complete", + headers=auth_header, + json={"totp_token": "123456"} + ) + assert response.status_code == 200 + + # Now disable MFA + response = test_client.post( + f"/console/api/account/mfa/disable", + headers=auth_header, + json={"password": "password"} # Default test password + ) + assert response.status_code == 200 + data = response.json + assert "disabled successfully" in data["message"] + + # Verify MFA is disabled + response = test_client.get( + f"/console/api/account/mfa/status", + headers=auth_header + ) + assert response.status_code == 200 + data = response.json + assert data["enabled"] is False \ No newline at end of file diff --git a/api/tests/unit_tests/controllers/__init__.py b/api/tests/unit_tests/controllers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/__init__.py b/api/tests/unit_tests/controllers/console/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/auth/__init__.py b/api/tests/unit_tests/controllers/console/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_mfa.py b/api/tests/unit_tests/controllers/console/auth/test_mfa.py new file mode 100644 index 0000000000..97fa789af0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_mfa.py @@ -0,0 +1,266 @@ +import pytest +from unittest.mock import patch + +from services.account_service import AccountService +from services.mfa_service import MFAService + + +class TestMFAEndpoints: + """Test MFA endpoints using integration test approach.""" + + @pytest.fixture + def auth_header(self, setup_account): + """Get authentication header with JWT token.""" + token = AccountService.get_account_jwt_token(setup_account) + return {"Authorization": f"Bearer {token}"} + + def test_mfa_status_success(self, test_client, setup_account, auth_header): + """Test successful MFA status check.""" + with patch.object(MFAService, 'get_mfa_status') as mock_status: + mock_status.return_value = {"enabled": False, "setup_at": None} + + response = test_client.get( + '/console/api/account/mfa/status', + headers=auth_header + ) + + assert response.status_code == 200 + data = response.json + assert data["enabled"] is False + assert data["setup_at"] is None + mock_status.assert_called_once_with(setup_account) + + def test_mfa_setup_init_success(self, test_client, setup_account, auth_header): + """Test successful MFA setup initialization.""" + with patch.object(MFAService, 'get_mfa_status') as mock_status: + with patch.object(MFAService, 'generate_mfa_setup_data') as mock_generate: + mock_status.return_value = {"enabled": False} + mock_generate.return_value = { + "secret": "TEST_SECRET", + "qr_code": "data:image/png;base64,test" + } + + response = test_client.post( + '/console/api/account/mfa/setup', + headers=auth_header + ) + + assert response.status_code == 200 + data = response.json + assert data["secret"] == "TEST_SECRET" + assert data["qr_code"] == "data:image/png;base64,test" + mock_generate.assert_called_once_with(setup_account) + + def test_mfa_setup_init_already_enabled(self, test_client, setup_account, auth_header): + """Test MFA setup initialization when already enabled.""" + with patch.object(MFAService, 'get_mfa_status') as mock_status: + mock_status.return_value = {"enabled": True, "setup_at": "2024-01-01T00:00:00"} + + response = test_client.post( + '/console/api/account/mfa/setup', + headers=auth_header + ) + + assert response.status_code == 400 + data = response.json + assert data["error"] == "MFA is already enabled" + + def test_mfa_setup_complete_success(self, test_client, setup_account, auth_header): + """Test successful MFA setup completion.""" + with patch.object(MFAService, 'setup_mfa') as mock_setup: + mock_setup.return_value = { + "message": "MFA has been successfully enabled", + "backup_codes": ["CODE1", "CODE2", "CODE3", "CODE4", "CODE5", "CODE6", "CODE7", "CODE8"], + "setup_at": "2024-01-01T00:00:00" + } + + response = test_client.post( + '/console/api/account/mfa/setup/complete', + headers=auth_header, + json={"totp_token": "123456"} + ) + + assert response.status_code == 200 + data = response.json + assert data["message"] == "MFA has been successfully enabled" + assert len(data["backup_codes"]) == 8 + mock_setup.assert_called_once_with(setup_account, "123456") + + def test_mfa_setup_complete_missing_token(self, test_client, setup_account, auth_header): + """Test MFA setup completion with missing token.""" + response = test_client.post( + '/console/api/account/mfa/setup/complete', + headers=auth_header, + json={} + ) + + assert response.status_code == 400 + data = response.json + assert "totp_token is required" in data["error"] + + def test_mfa_setup_complete_invalid_token(self, test_client, setup_account, auth_header): + """Test MFA setup completion with invalid token.""" + with patch.object(MFAService, 'setup_mfa') as mock_setup: + mock_setup.side_effect = ValueError("Invalid TOTP token") + + response = test_client.post( + '/console/api/account/mfa/setup/complete', + headers=auth_header, + json={"totp_token": "999999"} + ) + + assert response.status_code == 400 + data = response.json + assert "Invalid TOTP token" in data["error"] + + def test_mfa_disable_success(self, test_client, setup_account, auth_header): + """Test successful MFA disable.""" + with patch.object(MFAService, 'disable_mfa') as mock_disable: + mock_disable.return_value = {"message": "MFA has been disabled"} + + response = test_client.post( + '/console/api/account/mfa/disable', + headers=auth_header, + json={"password": "test_password"} + ) + + assert response.status_code == 200 + data = response.json + assert data["message"] == "MFA has been disabled" + mock_disable.assert_called_once_with(setup_account, "test_password") + + def test_mfa_disable_wrong_password(self, test_client, setup_account, auth_header): + """Test MFA disable with wrong password.""" + with patch.object(MFAService, 'disable_mfa') as mock_disable: + mock_disable.side_effect = ValueError("Invalid password") + + response = test_client.post( + '/console/api/account/mfa/disable', + headers=auth_header, + json={"password": "wrong_password"} + ) + + assert response.status_code == 400 + data = response.json + assert "Invalid password" in data["error"] + + def test_mfa_disable_not_enabled(self, test_client, setup_account, auth_header): + """Test MFA disable when not enabled.""" + with patch.object(MFAService, 'disable_mfa') as mock_disable: + mock_disable.side_effect = ValueError("MFA is not enabled") + + response = test_client.post( + '/console/api/account/mfa/disable', + headers=auth_header, + json={"password": "test_password"} + ) + + assert response.status_code == 400 + data = response.json + assert "MFA is not enabled" in data["error"] + + def test_mfa_verify_success(self, test_client): + """Test successful MFA verification during login.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + with patch.object(MFAService, 'is_mfa_required') as mock_required: + with patch.object(MFAService, 'authenticate_with_mfa') as mock_verify: + # Mock user exists + from models.account import Account + mock_account = Account( + id="test-id", + email="test@example.com", + name="Test User" + ) + mock_auth.return_value = mock_account + mock_required.return_value = True + mock_verify.return_value = True + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "test@example.com", + "mfa_code": "123456", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "success" + + def test_mfa_verify_invalid_token(self, test_client): + """Test MFA verification with invalid token.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + with patch.object(MFAService, 'is_mfa_required') as mock_required: + with patch.object(MFAService, 'authenticate_with_mfa') as mock_verify: + # Mock user exists + from models.account import Account + mock_account = Account( + id="test-id", + email="test@example.com", + name="Test User" + ) + mock_auth.return_value = mock_account + mock_required.return_value = True + mock_verify.return_value = False + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "test@example.com", + "mfa_code": "999999", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "fail" + assert data["code"] == "mfa_token_invalid" + + def test_mfa_verify_not_required(self, test_client): + """Test MFA verification when MFA is not required.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + with patch.object(MFAService, 'is_mfa_required') as mock_required: + # Mock user exists + from models.account import Account + mock_account = Account( + id="test-id", + email="test@example.com", + name="Test User" + ) + mock_auth.return_value = mock_account + mock_required.return_value = False + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "test@example.com", + "mfa_code": "123456", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "fail" + assert data["code"] == "mfa_not_required" + + def test_mfa_verify_account_not_found(self, test_client): + """Test MFA verification with non-existent account.""" + with patch('services.account_service.AccountService.authenticate') as mock_auth: + mock_auth.return_value = None + + response = test_client.post( + '/console/api/mfa/verify', + json={ + "email": "nonexistent@example.com", + "mfa_code": "123456", + "remember_me": True + } + ) + + assert response.status_code == 200 + data = response.json + assert data["result"] == "fail" + assert data["code"] == "mfa_verify_failed" \ No newline at end of file diff --git a/api/tests/unit_tests/controllers/console/auth/test_mfa_fixed.py b/api/tests/unit_tests/controllers/console/auth/test_mfa_fixed.py new file mode 100644 index 0000000000..fa1fd7b17f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_mfa_fixed.py @@ -0,0 +1,137 @@ +import pytest +from unittest.mock import patch +from flask import Flask +from flask_login import LoginManager + +from models.account import Account, AccountStatus +from extensions.ext_database import db +from services.mfa_service import MFAService + + +class TestMFAEndpointsFixed: + """Test MFA endpoints using proper Flask test client approach.""" + + @pytest.fixture + def setup_flask_app(self, app): + """Set up Flask app with proper login manager.""" + # This fixture uses the app from conftest which already has LoginManager configured + return app + + @pytest.fixture + def test_account(self, setup_flask_app): + """Create a test account.""" + with setup_flask_app.app_context(): + account = Account( + id="test-account-id", + email="test@example.com", + name="Test User", + password="hashed_password", + status=AccountStatus.ACTIVE.value, + password_salt="salt" + ) + db.session.add(account) + db.session.commit() + yield account + # Cleanup + db.session.delete(account) + db.session.commit() + + @pytest.fixture + def auth_headers(self, setup_flask_app, test_account): + """Get authentication headers by simulating login.""" + with setup_flask_app.test_client() as client: + # Mock the authentication to return our test account + with patch('services.account_service.AccountService.authenticate') as mock_auth: + mock_auth.return_value = test_account + + # Perform login to get token + response = client.post('/console/api/login', json={ + 'email': test_account.email, + 'password': 'test_password' + }) + + # Extract token from response + token = response.json.get('data', {}).get('access_token') + return {'Authorization': f'Bearer {token}'} + + def test_mfa_status_success(self, setup_flask_app, test_account, auth_headers): + """Test successful MFA status check.""" + with setup_flask_app.test_client() as client: + with setup_flask_app.app_context(): + # Mock the MFA service + with patch.object(MFAService, 'get_mfa_status') as mock_status: + mock_status.return_value = {"enabled": False, "setup_at": None} + + response = client.get( + '/console/api/account/mfa/status', + headers=auth_headers + ) + + assert response.status_code == 200 + data = response.json + assert data["enabled"] is False + assert data["setup_at"] is None + + def test_mfa_setup_init_success(self, setup_flask_app, test_account, auth_headers): + """Test successful MFA setup initialization.""" + with setup_flask_app.test_client() as client: + with setup_flask_app.app_context(): + # Mock MFA service methods + with patch.object(MFAService, 'get_mfa_status') as mock_status: + with patch.object(MFAService, 'generate_mfa_setup_data') as mock_generate: + mock_status.return_value = {"enabled": False} + mock_generate.return_value = { + "secret": "TEST_SECRET", + "qr_code": "data:image/png;base64,test" + } + + response = client.post( + '/console/api/account/mfa/setup', + headers=auth_headers + ) + + assert response.status_code == 200 + data = response.json + assert data["secret"] == "TEST_SECRET" + assert data["qr_code"] == "data:image/png;base64,test" + + def test_mfa_setup_complete_success(self, setup_flask_app, test_account, auth_headers): + """Test successful MFA setup completion.""" + with setup_flask_app.test_client() as client: + with setup_flask_app.app_context(): + # Mock MFA service + with patch.object(MFAService, 'setup_mfa') as mock_setup: + mock_setup.return_value = { + "message": "MFA has been successfully enabled", + "backup_codes": ["CODE1", "CODE2", "CODE3", "CODE4", "CODE5", "CODE6", "CODE7", "CODE8"], + "setup_at": "2024-01-01T00:00:00" + } + + response = client.post( + '/console/api/account/mfa/setup/complete', + headers=auth_headers, + json={"totp_token": "123456"} + ) + + assert response.status_code == 200 + data = response.json + assert data["message"] == "MFA has been successfully enabled" + assert len(data["backup_codes"]) == 8 + + def test_mfa_disable_success(self, setup_flask_app, test_account, auth_headers): + """Test successful MFA disable.""" + with setup_flask_app.test_client() as client: + with setup_flask_app.app_context(): + # Mock MFA service + with patch.object(MFAService, 'disable_mfa') as mock_disable: + mock_disable.return_value = {"message": "MFA has been disabled"} + + response = client.post( + '/console/api/account/mfa/disable', + headers=auth_headers, + json={"password": "test_password"} + ) + + assert response.status_code == 200 + data = response.json + assert data["message"] == "MFA has been disabled" \ No newline at end of file diff --git a/api/tests/unit_tests/controllers/console/auth/test_mfa_minimal.py b/api/tests/unit_tests/controllers/console/auth/test_mfa_minimal.py new file mode 100644 index 0000000000..6473f6e58b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_mfa_minimal.py @@ -0,0 +1,52 @@ +"""Minimal unit tests for MFA controllers to verify they're importable and basic structure.""" + +import pytest +from unittest.mock import MagicMock + +from controllers.console.auth.mfa import ( + MFASetupInitApi, + MFASetupCompleteApi, + MFADisableApi, + MFAStatusApi, + MFAVerifyApi +) + + +class TestMFAControllersMinimal: + """Minimal tests to verify MFA controllers are properly defined.""" + + def test_mfa_controllers_exist(self): + """Test that all MFA controller classes exist.""" + assert MFASetupInitApi is not None + assert MFASetupCompleteApi is not None + assert MFADisableApi is not None + assert MFAStatusApi is not None + assert MFAVerifyApi is not None + + def test_mfa_controllers_have_methods(self): + """Test that MFA controllers have expected methods.""" + # Setup Init has both GET and POST + assert hasattr(MFASetupInitApi, 'get') + assert hasattr(MFASetupInitApi, 'post') + + # Setup Complete has POST + assert hasattr(MFASetupCompleteApi, 'post') + + # Disable has POST + assert hasattr(MFADisableApi, 'post') + + # Status has GET + assert hasattr(MFAStatusApi, 'get') + + # Verify has POST + assert hasattr(MFAVerifyApi, 'post') + + def test_mfa_controller_inheritance(self): + """Test that MFA controllers inherit from Resource.""" + from flask_restful import Resource + + assert issubclass(MFASetupInitApi, Resource) + assert issubclass(MFASetupCompleteApi, Resource) + assert issubclass(MFADisableApi, Resource) + assert issubclass(MFAStatusApi, Resource) + assert issubclass(MFAVerifyApi, Resource) \ No newline at end of file diff --git a/api/tests/unit_tests/services/test_mfa_service.py b/api/tests/unit_tests/services/test_mfa_service.py new file mode 100644 index 0000000000..a820b8c93e --- /dev/null +++ b/api/tests/unit_tests/services/test_mfa_service.py @@ -0,0 +1,370 @@ +import json +import unittest +from unittest.mock import Mock, patch +from datetime import datetime, timezone + +import pytest + +from models.account import Account, AccountMFASettings +from services.mfa_service import MFAService + + +class TestMFAService(unittest.TestCase): + def setUp(self): + self.account = Mock(spec=Account) + self.account.id = "test-account-id" + self.account.email = "test@example.com" + self.account.password = "hashed_password" + self.account.password_salt = "salt" + + self.mfa_settings = Mock(spec=AccountMFASettings) + self.mfa_settings.account_id = self.account.id + self.mfa_settings.enabled = False + self.mfa_settings.secret = None + self.mfa_settings.backup_codes = None + self.mfa_settings.setup_at = None + + def test_generate_secret(self): + """Test secret generation.""" + secret = MFAService.generate_secret() + self.assertIsInstance(secret, str) + self.assertEqual(len(secret), 32) # Base32 length + + def test_generate_backup_codes(self): + """Test backup codes generation.""" + codes = MFAService.generate_backup_codes() + self.assertEqual(len(codes), 8) + for code in codes: + self.assertIsInstance(code, str) + self.assertEqual(len(code), 8) # 4 hex bytes = 8 chars + + @patch('pyotp.TOTP') + def test_verify_totp_valid(self, mock_totp_class): + """Test TOTP verification with valid token.""" + mock_totp = Mock() + mock_totp.verify.return_value = True + mock_totp_class.return_value = mock_totp + + result = MFAService.verify_totp("test_secret", "123456") + + self.assertTrue(result) + mock_totp.verify.assert_called_once_with("123456", valid_window=1) + + @patch('pyotp.TOTP') + def test_verify_totp_invalid(self, mock_totp_class): + """Test TOTP verification with invalid token.""" + mock_totp = Mock() + mock_totp.verify.return_value = False + mock_totp_class.return_value = mock_totp + + result = MFAService.verify_totp("test_secret", "invalid") + + self.assertFalse(result) + + def test_verify_totp_no_secret(self): + """Test TOTP verification with no secret.""" + result = MFAService.verify_totp(None, "123456") + self.assertFalse(result) + + @patch('services.mfa_service.db.session') + def test_get_or_create_mfa_settings_existing(self, mock_session): + """Test getting existing MFA settings.""" + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + + result = MFAService.get_or_create_mfa_settings(self.account) + + self.assertEqual(result, self.mfa_settings) + mock_session.query.assert_called_once() + + @patch('services.mfa_service.db.session') + def test_get_or_create_mfa_settings_new(self, mock_session): + """Test creating new MFA settings.""" + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + result = MFAService.get_or_create_mfa_settings(self.account) + + # Check that new settings were created + self.assertIsInstance(result, AccountMFASettings) + self.assertEqual(result.account_id, self.account.id) + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + @patch('services.mfa_service.db.session') + def test_verify_backup_code_valid(self, mock_session): + """Test backup code verification with valid code.""" + self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"]) + + result = MFAService.verify_backup_code(self.mfa_settings, "abcd1234") # Test case insensitive + + self.assertTrue(result) + # Check that the code was removed + remaining_codes = json.loads(self.mfa_settings.backup_codes) + self.assertNotIn("ABCD1234", remaining_codes) + self.assertIn("EFGH5678", remaining_codes) + mock_session.commit.assert_called_once() + + def test_verify_backup_code_invalid(self): + """Test backup code verification with invalid code.""" + self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"]) + + result = MFAService.verify_backup_code(self.mfa_settings, "INVALID") + + self.assertFalse(result) + + def test_verify_backup_code_no_codes(self): + """Test backup code verification with no backup codes.""" + self.mfa_settings.backup_codes = None + + result = MFAService.verify_backup_code(self.mfa_settings, "ABCD1234") + + self.assertFalse(result) + + @patch('services.mfa_service.MFAService.get_or_create_mfa_settings') + @patch('services.mfa_service.MFAService.verify_totp') + @patch('services.mfa_service.MFAService.generate_backup_codes') + @patch('services.mfa_service.db.session') + def test_setup_mfa_success(self, mock_session, mock_gen_codes, mock_verify, mock_get_settings): + """Test successful MFA setup.""" + mock_get_settings.return_value = self.mfa_settings + self.mfa_settings.secret = "test_secret" + mock_verify.return_value = True + mock_gen_codes.return_value = ["CODE1", "CODE2"] + + result = MFAService.setup_mfa(self.account, "123456") + + self.assertTrue(self.mfa_settings.enabled) + self.assertEqual(self.mfa_settings.backup_codes, json.dumps(["CODE1", "CODE2"])) + self.assertIsNotNone(self.mfa_settings.setup_at) + self.assertEqual(result["backup_codes"], ["CODE1", "CODE2"]) + + @patch('services.mfa_service.MFAService.get_or_create_mfa_settings') + def test_setup_mfa_already_enabled(self, mock_get_settings): + """Test MFA setup when already enabled.""" + self.mfa_settings.enabled = True + mock_get_settings.return_value = self.mfa_settings + + with self.assertRaises(ValueError) as context: + MFAService.setup_mfa(self.account, "123456") + + self.assertIn("already enabled", str(context.exception)) + + @patch('services.mfa_service.MFAService.get_or_create_mfa_settings') + def test_setup_mfa_no_secret(self, mock_get_settings): + """Test MFA setup without secret.""" + mock_get_settings.return_value = self.mfa_settings + + with self.assertRaises(ValueError) as context: + MFAService.setup_mfa(self.account, "123456") + + self.assertIn("secret not generated", str(context.exception)) + + @patch('services.mfa_service.MFAService.get_or_create_mfa_settings') + @patch('services.mfa_service.MFAService.verify_totp') + def test_setup_mfa_invalid_token(self, mock_verify, mock_get_settings): + """Test MFA setup with invalid TOTP token.""" + mock_get_settings.return_value = self.mfa_settings + self.mfa_settings.secret = "test_secret" + mock_verify.return_value = False + + with self.assertRaises(ValueError) as context: + MFAService.setup_mfa(self.account, "invalid") + + self.assertIn("Invalid TOTP token", str(context.exception)) + + @patch('services.mfa_service.db.session') + def test_is_mfa_required_enabled(self, mock_session): + """Test MFA requirement check when enabled.""" + self.mfa_settings.enabled = True + self.mfa_settings.secret = "test_secret" + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + + result = MFAService.is_mfa_required(self.account) + + self.assertTrue(result) + + @patch('services.mfa_service.db.session') + def test_is_mfa_required_disabled(self, mock_session): + """Test MFA requirement check when disabled.""" + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + + result = MFAService.is_mfa_required(self.account) + + self.assertFalse(result) + + @patch('services.mfa_service.db.session') + def test_is_mfa_required_no_settings(self, mock_session): + """Test MFA requirement check with no settings.""" + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + result = MFAService.is_mfa_required(self.account) + + self.assertFalse(result) + + @patch('services.mfa_service.db.session') + @patch('services.mfa_service.MFAService.verify_totp') + @patch('services.mfa_service.MFAService.verify_backup_code') + def test_authenticate_with_mfa_totp_success(self, mock_verify_backup, mock_verify_totp, mock_session): + """Test MFA authentication with valid TOTP.""" + self.mfa_settings.enabled = True + self.mfa_settings.secret = "test_secret" + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + mock_verify_totp.return_value = True + + result = MFAService.authenticate_with_mfa(self.account, "123456") + + self.assertTrue(result) + mock_verify_totp.assert_called_once_with("test_secret", "123456") + mock_verify_backup.assert_not_called() + + @patch('services.mfa_service.db.session') + @patch('services.mfa_service.MFAService.verify_totp') + @patch('services.mfa_service.MFAService.verify_backup_code') + def test_authenticate_with_mfa_backup_success(self, mock_verify_backup, mock_verify_totp, mock_session): + """Test MFA authentication with valid backup code.""" + self.mfa_settings.enabled = True + self.mfa_settings.secret = "test_secret" + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + mock_verify_totp.return_value = False + mock_verify_backup.return_value = True + + result = MFAService.authenticate_with_mfa(self.account, "BACKUP123") + + self.assertTrue(result) + mock_verify_totp.assert_called_once_with("test_secret", "BACKUP123") + mock_verify_backup.assert_called_once_with(self.mfa_settings, "BACKUP123") + + @patch('services.mfa_service.db.session') + def test_authenticate_with_mfa_disabled(self, mock_session): + """Test MFA authentication when disabled.""" + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + + result = MFAService.authenticate_with_mfa(self.account, "123456") + + self.assertTrue(result) + + @patch('services.mfa_service.db.session') + def test_get_mfa_status_enabled(self, mock_session): + """Test getting MFA status when enabled.""" + self.mfa_settings.enabled = True + self.mfa_settings.setup_at = datetime(2025, 1, 1, 12, 0, 0) + self.mfa_settings.backup_codes = json.dumps(["CODE1", "CODE2"]) + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + + result = MFAService.get_mfa_status(self.account) + + expected = { + "enabled": True, + "setup_at": "2025-01-01T12:00:00", + "has_backup_codes": True + } + self.assertEqual(result, expected) + + @patch('services.mfa_service.db.session') + def test_get_mfa_status_no_settings(self, mock_session): + """Test getting MFA status with no settings.""" + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + result = MFAService.get_mfa_status(self.account) + + expected = { + "enabled": False, + "setup_at": None, + "has_backup_codes": False + } + self.assertEqual(result, expected) + + @patch('qrcode.QRCode') + @patch('pyotp.TOTP') + def test_generate_qr_code(self, mock_totp_class, mock_qr_class): + """Test QR code generation.""" + # Mock TOTP + mock_totp = Mock() + mock_totp.provisioning_uri.return_value = "otpauth://totp/test" + mock_totp_class.return_value = mock_totp + + # Mock QR code + mock_qr = Mock() + mock_img = Mock() + mock_qr.make_image.return_value = mock_img + mock_qr_class.return_value = mock_qr + + # Mock image buffer + with patch('io.BytesIO') as mock_buffer, \ + patch('base64.b64encode') as mock_b64: + mock_b64.return_value.decode.return_value = "base64data" + + result = MFAService.generate_qr_code(self.account, "test_secret") + + self.assertEqual(result, "data:image/png;base64,base64data") + mock_totp.provisioning_uri.assert_called_once_with( + name=self.account.email, + issuer_name="Dify" + ) + + @patch('libs.password.compare_password') + @patch('services.mfa_service.db.session') + def test_disable_mfa_success(self, mock_session, mock_compare_password): + """Test successful MFA disable.""" + mock_compare_password.return_value = True + mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings + + result = MFAService.disable_mfa(self.account, "correct_password") + + self.assertTrue(result) + self.assertFalse(self.mfa_settings.enabled) + self.assertIsNone(self.mfa_settings.secret) + self.assertIsNone(self.mfa_settings.backup_codes) + self.assertIsNone(self.mfa_settings.setup_at) + mock_session.commit.assert_called_once() + + @patch('libs.password.compare_password') + def test_disable_mfa_wrong_password(self, mock_compare_password): + """Test MFA disable with wrong password.""" + mock_compare_password.return_value = False + + result = MFAService.disable_mfa(self.account, "wrong_password") + + self.assertFalse(result) + + @patch('libs.password.compare_password') + @patch('services.mfa_service.db.session') + def test_disable_mfa_no_settings(self, mock_session, mock_compare_password): + """Test MFA disable when no settings exist.""" + mock_compare_password.return_value = True + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + result = MFAService.disable_mfa(self.account, "correct_password") + + self.assertTrue(result) # Already disabled + + @patch('services.mfa_service.MFAService.get_or_create_mfa_settings') + @patch('services.mfa_service.MFAService.generate_secret') + @patch('services.mfa_service.MFAService.generate_qr_code') + @patch('services.mfa_service.db.session') + def test_generate_mfa_setup_data_success(self, mock_session, mock_gen_qr, mock_gen_secret, mock_get_settings): + """Test successful MFA setup data generation.""" + mock_get_settings.return_value = self.mfa_settings + mock_gen_secret.return_value = "NEWSECRET123" + mock_gen_qr.return_value = "data:image/png;base64,qrdata" + + result = MFAService.generate_mfa_setup_data(self.account) + + self.assertEqual(result["secret"], "NEWSECRET123") + self.assertEqual(result["qr_code"], "data:image/png;base64,qrdata") + self.assertEqual(self.mfa_settings.secret, "NEWSECRET123") + mock_session.commit.assert_called_once() + + @patch('services.mfa_service.MFAService.get_or_create_mfa_settings') + def test_generate_mfa_setup_data_already_enabled(self, mock_get_settings): + """Test MFA setup data generation when already enabled.""" + self.mfa_settings.enabled = True + mock_get_settings.return_value = self.mfa_settings + + with self.assertRaises(ValueError) as context: + MFAService.generate_mfa_setup_data(self.account) + + self.assertIn("already enabled", str(context.exception)) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/api/uv.lock b/api/uv.lock index 21b6b20f53..ccf429b60b 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1281,10 +1281,12 @@ dependencies = [ { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "pyotp" }, { name = "pypdfium2" }, { name = "python-docx" }, { name = "python-dotenv" }, { name = "pyyaml" }, + { name = "qrcode", extra = ["pil"] }, { name = "readabilipy" }, { name = "redis", extra = ["hiredis"] }, { name = "resend" }, @@ -1463,10 +1465,12 @@ requires-dist = [ { name = "pydantic-extra-types", specifier = "~=2.10.3" }, { name = "pydantic-settings", specifier = "~=2.9.1" }, { name = "pyjwt", specifier = "~=2.8.0" }, + { name = "pyotp", specifier = "~=2.9.0" }, { name = "pypdfium2", specifier = "==4.30.0" }, { name = "python-docx", specifier = "~=1.1.0" }, { name = "python-dotenv", specifier = "==1.0.1" }, { name = "pyyaml", specifier = "~=6.0.1" }, + { name = "qrcode", extras = ["pil"], specifier = "~=7.4.2" }, { name = "readabilipy", specifier = "~=0.3.0" }, { name = "redis", extras = ["hiredis"], specifier = "~=6.1.0" }, { name = "resend", specifier = "~=2.9.0" }, @@ -4469,6 +4473,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/68/ecb21b74c974e7be7f9034e205d08db62d614ff5c221581ae96d37ef853e/pyobvector-0.1.14-py3-none-any.whl", hash = "sha256:828e0bec49a177355b70c7a1270af3b0bf5239200ee0d096e4165b267eeff97c", size = 35526, upload-time = "2024-11-20T11:46:16.809Z" }, ] +[[package]] +name = "pyotp" +version = "2.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/b2/1d5994ba2acde054a443bd5e2d384175449c7d2b6d1a0614dbca3a63abfc/pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63", size = 17763, upload-time = "2023-07-27T23:41:03.295Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/c0/c33c8792c3e50193ef55adb95c1c3c2786fe281123291c2dbf0eaab95a6f/pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612", size = 13376, upload-time = "2023-07-27T23:41:01.685Z" }, +] + [[package]] name = "pypandoc" version = "1.15" @@ -4522,6 +4535,15 @@ version = "0.48.9" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/c7/2c/94ed7b91db81d61d7096ac8f2d325ec562fc75e35f3baea8749c85b28784/PyPika-0.48.9.tar.gz", hash = "sha256:838836a61747e7c8380cd1b7ff638694b7a7335345d0f559b04b2cd832ad5378", size = 67259, upload-time = "2022-03-15T11:22:57.066Z" } +[[package]] +name = "pypng" +version = "0.20220715.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/93/cd/112f092ec27cca83e0516de0a3368dbd9128c187fb6b52aaaa7cde39c96d/pypng-0.20220715.0.tar.gz", hash = "sha256:739c433ba96f078315de54c0db975aee537cbc3e1d0ae4ed9aab0ca1e427e2c1", size = 128992, upload-time = "2022-07-15T14:11:05.301Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3e/b9/3766cc361d93edb2ce81e2e1f87dd98f314d7d513877a342d31b30741680/pypng-0.20220715.0-py3-none-any.whl", hash = "sha256:4a43e969b8f5aaafb2a415536c1a8ec7e341cd6a3f957fd5b5f32a4cfeed902c", size = 58057, upload-time = "2022-07-15T14:11:03.713Z" }, +] + [[package]] name = "pyproject-hooks" version = "1.2.0" @@ -4807,6 +4829,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/fa/5abd82cde353f1009c068cca820195efd94e403d261b787e78ea7a9c8318/qdrant_client-1.9.0-py3-none-any.whl", hash = "sha256:ee02893eab1f642481b1ac1e38eb68ec30bab0f673bef7cc05c19fa5d2cbf43e", size = 229258, upload-time = "2024-04-22T13:35:46.81Z" }, ] +[[package]] +name = "qrcode" +version = "7.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "pypng" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/35/ad6d4c5a547fe9a5baf85a9edbafff93fc6394b014fab30595877305fa59/qrcode-7.4.2.tar.gz", hash = "sha256:9dd969454827e127dbd93696b20747239e6d540e082937c90f14ac95b30f5845", size = 535974, upload-time = "2023-02-05T22:11:46.548Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/79/aaf0c1c7214f2632badb2771d770b1500d3d7cbdf2590ae62e721ec50584/qrcode-7.4.2-py3-none-any.whl", hash = "sha256:581dca7a029bcb2deef5d01068e39093e80ef00b4a61098a2182eac59d01643a", size = 46197, upload-time = "2023-02-05T22:11:43.4Z" }, +] + +[package.optional-dependencies] +pil = [ + { name = "pillow" }, +] + [[package]] name = "rapidfuzz" version = "3.13.0" diff --git a/web/.npmrc b/web/.npmrc new file mode 100644 index 0000000000..8280b762ae --- /dev/null +++ b/web/.npmrc @@ -0,0 +1 @@ +store-dir=../.pnpm-store \ No newline at end of file diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 9b36fc6020..e3fea6b351 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -120,7 +120,9 @@ export default function AppSelector() {
setShowAccountSettingModal({ payload: 'members' })}> + )} onClick={() => { + setShowAccountSettingModal({ payload: 'members' }) + }}>
{t('common.userProfile.settings')}
diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index b2a3c8245b..3a979e6715 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -16,10 +16,13 @@ import { RiPuzzle2Fill, RiPuzzle2Line, RiTranslate2, + RiShieldKeyholeLine, + RiShieldKeyholeFill, } from '@remixicon/react' import Button from '../../base/button' import MembersPage from './members-page' import LanguagePage from './language-page' +import MFAPage from './mfa-page' import ApiBasedExtensionPage from './api-based-extension-page' import DataSourcePage from './data-source-page' import ModelProviderPage from './model-provider-page' @@ -53,10 +56,13 @@ export default function AccountSetting({ onCancel, activeTab = 'members', }: IAccountSettingProps) { - const [activeMenu, setActiveMenu] = useState(activeTab) const { t } = useTranslation() const { enableBilling, enableReplaceWebAppLogo } = useProviderContext() const { isCurrentWorkspaceDatasetOperator } = useAppContext() + + // Set appropriate default tab based on user role + const defaultTab = isCurrentWorkspaceDatasetOperator ? 'mfa' : activeTab + const [activeMenu, setActiveMenu] = useState(defaultTab) const workplaceGroupItems = (() => { if (isCurrentWorkspaceDatasetOperator) @@ -116,6 +122,12 @@ export default function AccountSetting({ key: 'account-group', name: t('common.settings.generalGroup'), items: [ + { + key: 'mfa', + name: t('common.settings.mfa'), + icon: , + activeIcon: , + }, { key: 'language', name: t('common.settings.language'), @@ -125,8 +137,11 @@ export default function AccountSetting({ ], }, ] + const scrollRef = useRef(null) const [scrolled, setScrolled] = useState(false) + + useEffect(() => { const targetElement = scrollRef.current const scrollHandle = (e: Event) => { @@ -155,7 +170,7 @@ export default function AccountSetting({ { menuItems.map(menuItem => (
- {!isCurrentWorkspaceDatasetOperator && ( + {menuItem.items.length > 0 && (
{menuItem.name}
)}
@@ -219,6 +234,7 @@ export default function AccountSetting({ {activeMenu === 'data-source' && } {activeMenu === 'api-based-extension' && } {activeMenu === 'custom' && } + {activeMenu === 'mfa' && } {activeMenu === 'language' && }
diff --git a/web/app/components/header/account-setting/mfa-page.test.tsx b/web/app/components/header/account-setting/mfa-page.test.tsx new file mode 100644 index 0000000000..a81338e90f --- /dev/null +++ b/web/app/components/header/account-setting/mfa-page.test.tsx @@ -0,0 +1,346 @@ +import React from 'react' +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import '@testing-library/jest-dom' + +// Mock the service base to avoid ky import issues +jest.mock('@/service/base', () => ({ + get: jest.fn(), + post: jest.fn(), + put: jest.fn(), + del: jest.fn(), +})) + +// Mock the translation hook +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +import MFAPage from './mfa-page' + +// Mock the Toast component +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: jest.fn(), + }, +})) + +// Mock Modal component +jest.mock('@/app/components/base/modal', () => ({ + __esModule: true, + default: ({ isOpen, onClose, children }: any) => + isOpen ?
{children}
: null, +})) + +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' + +// Create a test wrapper component +const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + + return ({ children }: { children: React.ReactNode }) => ( + + {children} + + ) +} + +describe('MFAPage Component', () => { + let wrapper: ReturnType + + beforeEach(() => { + jest.clearAllMocks() + wrapper = createWrapper() + }) + + test('renders loading state initially', () => { + const { get } = require('@/service/base') + get.mockImplementation(() => new Promise(() => {})) // Never resolves + + render(, { wrapper }) + + expect(screen.getByText('Loading...')).toBeInTheDocument() + }) + + test('renders enable button when MFA is disabled', async () => { + const { get } = require('@/service/base') + get.mockResolvedValue({ enabled: false }) + + render(, { wrapper }) + + await waitFor(() => { + expect(screen.getByText('mfa.enable')).toBeInTheDocument() + }) + }) + + test('renders disable button when MFA is enabled', async () => { + const { get } = require('@/service/base') + get.mockResolvedValue({ + enabled: true, + setup_at: '2025-01-01T12:00:00' + }) + + render(, { wrapper }) + + await waitFor(() => { + expect(screen.getByText('mfa.disable')).toBeInTheDocument() + }) + }) + + test('opens setup modal when enable button is clicked', async () => { + const { get, post } = require('@/service/base') + get.mockResolvedValue({ enabled: false }) + post.mockResolvedValue({ + secret: 'TEST_SECRET', + qr_code: 'data:image/png;base64,test' + }) + + render(, { wrapper }) + + await waitFor(() => { + expect(screen.getByText('mfa.enable')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByText('mfa.enable')) + + await waitFor(() => { + expect(screen.getByTestId('modal')).toBeInTheDocument() + expect(screen.getByText('mfa.setupTitle')).toBeInTheDocument() + }) + }) + + test('completes MFA setup successfully', async () => { + const { get, post } = require('@/service/base') + const Toast = require('@/app/components/base/toast').default + + get.mockResolvedValue({ enabled: false }) + post.mockImplementation((url) => { + if (url.includes('/setup') && !url.includes('/complete')) { + return Promise.resolve({ + secret: 'TEST_SECRET', + qr_code: 'data:image/png;base64,test' + }) + } else if (url.includes('/setup/complete')) { + return Promise.resolve({ + message: 'MFA setup successfully', + backup_codes: ['CODE1', 'CODE2', 'CODE3', 'CODE4', 'CODE5', 'CODE6', 'CODE7', 'CODE8'], + setup_at: '2025-01-01T12:00:00' + }) + } + }) + + render(, { wrapper }) + + // Click enable + await waitFor(() => { + fireEvent.click(screen.getByText('mfa.enable')) + }) + + // Wait for QR code to be displayed + await waitFor(() => { + expect(screen.getByAltText('MFA QR Code')).toBeInTheDocument() + }) + + // Enter TOTP code + const inputs = screen.getAllByRole('textbox') + // Simulate entering '123456' + '123456'.split('').forEach((digit, index) => { + fireEvent.change(inputs[index], { target: { value: digit } }) + }) + + // Click verify button + const verifyButton = screen.getByRole('button', { name: /verify|mfa.verify/i }) + fireEvent.click(verifyButton) + + await waitFor(() => { + expect(Toast.notify).toHaveBeenCalledWith({ + type: 'success', + message: 'mfa.setupSuccess' + }) + }) + }) + + test('shows error when setup fails', async () => { + const { get, post } = require('@/service/base') + const Toast = require('@/app/components/base/toast').default + + get.mockResolvedValue({ enabled: false }) + post.mockImplementation((url) => { + if (url.includes('/setup') && !url.includes('/complete')) { + return Promise.resolve({ + secret: 'TEST_SECRET', + qr_code: 'data:image/png;base64,test' + }) + } else if (url.includes('/setup/complete')) { + return Promise.reject(new Error('Invalid TOTP token')) + } + }) + + render(, { wrapper }) + + // Click enable + await waitFor(() => { + fireEvent.click(screen.getByText('mfa.enable')) + }) + + // Wait for QR code + await waitFor(() => { + expect(screen.getByAltText('MFA QR Code')).toBeInTheDocument() + }) + + // Enter wrong TOTP code + const inputs = screen.getAllByRole('textbox') + '000000'.split('').forEach((digit, index) => { + fireEvent.change(inputs[index], { target: { value: digit } }) + }) + + // Click verify + const verifyButton = screen.getByRole('button', { name: /verify|mfa.verify/i }) + fireEvent.click(verifyButton) + + await waitFor(() => { + expect(Toast.notify).toHaveBeenCalledWith({ + type: 'error', + message: 'Invalid TOTP token' + }) + }) + }) + + test('disables MFA successfully', async () => { + const { get, post } = require('@/service/base') + const Toast = require('@/app/components/base/toast').default + + get.mockResolvedValue({ + enabled: true, + setup_at: '2025-01-01T12:00:00' + }) + post.mockImplementation((url) => { + if (url.includes('/disable')) { + return Promise.resolve({ + success: true, + message: 'MFA disabled successfully' + }) + } + }) + + render(, { wrapper }) + + // Click disable + await waitFor(() => { + fireEvent.click(screen.getByText('mfa.disable')) + }) + + // Modal should open + await waitFor(() => { + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + // Enter password + const passwordInput = screen.getByPlaceholderText('mfa.enterYourPassword') + fireEvent.change(passwordInput, { target: { value: 'password123' } }) + + // Click confirm + const confirmButton = screen.getByText('common.operation.confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(Toast.notify).toHaveBeenCalledWith({ + type: 'success', + message: 'mfa.disabledSuccessfully' + }) + }) + }) + + test('shows error when disable fails with wrong password', async () => { + const { get, post } = require('@/service/base') + const Toast = require('@/app/components/base/toast').default + + get.mockResolvedValue({ + enabled: true, + setup_at: '2025-01-01T12:00:00' + }) + post.mockImplementation((url) => { + if (url.includes('/disable')) { + return Promise.reject(new Error('Invalid password')) + } + }) + + render(, { wrapper }) + + // Click disable + await waitFor(() => { + fireEvent.click(screen.getByText('mfa.disable')) + }) + + // Enter wrong password + const passwordInput = screen.getByPlaceholderText('mfa.enterYourPassword') + fireEvent.change(passwordInput, { target: { value: 'wrongpassword' } }) + + // Click confirm + const confirmButton = screen.getByText('common.operation.confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(Toast.notify).toHaveBeenCalledWith({ + type: 'error', + message: 'Invalid password' + }) + }) + }) + + test('handles backup codes display correctly', async () => { + const { get, post } = require('@/service/base') + + get.mockResolvedValue({ enabled: false }) + post.mockImplementation((url) => { + if (url.includes('/setup') && !url.includes('/complete')) { + return Promise.resolve({ + secret: 'TEST_SECRET', + qr_code: 'data:image/png;base64,test' + }) + } else if (url.includes('/setup/complete')) { + return Promise.resolve({ + message: 'MFA setup successfully', + backup_codes: ['ABCD1234', 'EFGH5678', 'IJKL9012', 'MNOP3456', 'QRST7890', 'UVWX1234', 'YZAB5678', 'CDEF9012'], + setup_at: '2025-01-01T12:00:00' + }) + } + }) + + render(, { wrapper }) + + // Setup MFA + await waitFor(() => { + fireEvent.click(screen.getByText('mfa.enable')) + }) + + await waitFor(() => { + expect(screen.getByAltText('MFA QR Code')).toBeInTheDocument() + }) + + // Enter TOTP code + const inputs = screen.getAllByRole('textbox') + '123456'.split('').forEach((digit, index) => { + fireEvent.change(inputs[index], { target: { value: digit } }) + }) + + // Verify + const verifyButton = screen.getByRole('button', { name: /verify|mfa.verify/i }) + fireEvent.click(verifyButton) + + // Check backup codes are displayed + await waitFor(() => { + expect(screen.getByText('mfa.backupCodes')).toBeInTheDocument() + expect(screen.getByText('ABCD1234')).toBeInTheDocument() + expect(screen.getByText('EFGH5678')).toBeInTheDocument() + }) + }) +}) \ No newline at end of file diff --git a/web/app/components/header/account-setting/mfa-page.tsx b/web/app/components/header/account-setting/mfa-page.tsx new file mode 100644 index 0000000000..89bc155bcc --- /dev/null +++ b/web/app/components/header/account-setting/mfa-page.tsx @@ -0,0 +1,306 @@ +'use client' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { RiShieldKeyholeLine, RiCheckboxCircleFill, RiLoader2Line } from '@remixicon/react' +import Toast from '../../base/toast' +import Button from '../../base/button' +import Input from '../../base/input' +import Modal from '../../base/modal' +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' + +import { get, post } from '@/service/base' + +// API service functions +const mfaService = { + getStatus: async () => { + return get<{ + enabled: boolean + setup_at: string | null + }>('/account/mfa/status') + }, + + initSetup: async () => { + return post<{ + secret: string + qr_code: string + }>('/account/mfa/setup', { body: {} }) + }, + + completeSetup: async (totpToken: string, password: string) => { + return post<{ + message: string + backup_codes: string[] + setup_at: string + }>('/account/mfa/setup/complete', { + body: { totp_token: totpToken } + }) + }, + + disable: async (password: string) => { + return post('/account/mfa/disable', { + body: { password } + }) + }, +} + +export default function MFAPage() { + const { t } = useTranslation() + const queryClient = useQueryClient() + + // State + const [isSetupModalOpen, setIsSetupModalOpen] = useState(false) + const [isDisableModalOpen, setIsDisableModalOpen] = useState(false) + const [setupStep, setSetupStep] = useState<'qr' | 'verify' | 'backup'>('qr') + const [totpToken, setTotpToken] = useState('') + const [password, setPassword] = useState('') + const [qrData, setQrData] = useState<{ secret: string; qr_code: string } | null>(null) + const [backupCodes, setBackupCodes] = useState([]) + + // Query MFA status + const { data: mfaStatus, isLoading } = useQuery({ + queryKey: ['mfa-status'], + queryFn: mfaService.getStatus, + }) + + + // Mutations + const initSetupMutation = useMutation({ + mutationFn: mfaService.initSetup, + onSuccess: (data) => { + setQrData(data) + setIsSetupModalOpen(true) + setSetupStep('qr') + }, + onError: () => { + Toast.notify({ type: 'error', message: t('common.somethingWentWrong') }) + }, + }) + + const completeSetupMutation = useMutation({ + mutationFn: ({ totpToken, password }: { totpToken: string; password: string }) => + mfaService.completeSetup(totpToken, password), + onSuccess: (data) => { + setBackupCodes(data.backup_codes) + setSetupStep('backup') + queryClient.invalidateQueries({ queryKey: ['mfa-status'] }) + }, + onError: () => { + Toast.notify({ type: 'error', message: t('mfa.invalidToken') }) + }, + }) + + const disableMutation = useMutation({ + mutationFn: mfaService.disable, + onSuccess: () => { + setIsDisableModalOpen(false) + queryClient.invalidateQueries({ queryKey: ['mfa-status'] }) + Toast.notify({ type: 'success', message: t('mfa.disabledSuccess') }) + }, + onError: () => { + Toast.notify({ type: 'error', message: t('mfa.invalidPassword') }) + }, + }) + + const handleSetupStart = () => { + initSetupMutation.mutate() + } + + const handleVerifyToken = () => { + if (totpToken.length !== 6) { + Toast.notify({ type: 'error', message: t('mfa.tokenLength') }) + return + } + completeSetupMutation.mutate({ totpToken, password: '' }) + } + + const handleDisable = () => { + disableMutation.mutate(password) + } + + const handleCopyBackupCodes = () => { + const codesText = backupCodes.join('\n') + navigator.clipboard.writeText(codesText) + Toast.notify({ type: 'success', message: t('mfa.copied') }) + } + + if (isLoading) { + return ( +
+ +
+ ) + } + + return ( +
+
+
+ +
+
{t('mfa.description')}
+
+ {t('mfa.securityTip')} +
+
+ +
+
+
+
+ +
+
+
{t('mfa.authenticatorApp')}
+
{t('mfa.authenticatorDescription')}
+
+
+
+ {mfaStatus?.enabled && ( + + )} + +
+
+ + {mfaStatus?.enabled && mfaStatus?.setup_at && ( +
+ {t('mfa.enabledAt', { date: new Date(mfaStatus.setup_at).toLocaleDateString() })} +
+ )} +
+ + {/* Setup Modal */} + setIsSetupModalOpen(false)} + title={t('mfa.setupTitle')} + className="!max-w-md" + > + {setupStep === 'qr' && qrData && ( +
+

{t('mfa.scanQRCode')}

+
+ MFA QR Code +
+
+

{t('mfa.secretKey')}

+ {qrData.secret} +
+ +
+ )} + + {setupStep === 'verify' && ( +
+

{t('mfa.enterToken')}

+ setTotpToken(e.target.value)} + placeholder="000000" + maxLength={6} + className="text-center text-2xl font-mono" + /> + +
+ )} + + {setupStep === 'backup' && ( +
+
+

{t('mfa.backupCodesTitle')}

+

{t('mfa.backupCodesWarning')}

+
+
+
+ {backupCodes.map((code, index) => ( + {code} + ))} +
+
+
+ + +
+
+ )} +
+ + {/* Disable Modal */} + setIsDisableModalOpen(false)} + title={t('mfa.disableTitle')} + className="!max-w-md" + > +
+

{t('mfa.disableDescription')}

+ setPassword(e.target.value)} + placeholder={t('common.account.password')} + /> +
+ + +
+
+
+
+ ) +} \ No newline at end of file diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx index 7360fdac44..245e463f3d 100644 --- a/web/app/signin/components/mail-and-password-auth.tsx +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -10,6 +10,7 @@ import { login } from '@/service/common' import Input from '@/app/components/base/input' import I18NContext from '@/context/i18n' import { noop } from 'lodash-es' +import MFAVerification from './mfa-verification' type MailAndPasswordAuthProps = { isInvite: boolean @@ -28,6 +29,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis const emailFromLink = decodeURIComponent(searchParams.get('email') || '') const [email, setEmail] = useState(emailFromLink) const [password, setPassword] = useState('') + const [showMFAVerification, setShowMFAVerification] = useState(false) const [isLoading, setIsLoading] = useState(false) const handleEmailPasswordLogin = async () => { @@ -67,7 +69,12 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis url: '/login', body: loginData, }) - if (res.result === 'success') { + console.log('Login response:', res) + if (res.code === 'mfa_required') { + console.log('MFA required, showing MFA verification screen') + setShowMFAVerification(true) + } + else if (res.result === 'success') { if (isInvite) { router.replace(`/signin/invite-settings?${searchParams.toString()}`) } @@ -104,6 +111,18 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis } } + if (showMFAVerification) { + return ( + + ) + } + return