style: convert unittest assertions to pytest style

Convert unittest-style assertions (assertEqual, assertTrue, etc.) to
pytest-style assertions to comply with project linting standards.

Applied via ruff --unsafe-fixes to ensure consistency.

 All 28 tests pass with pytest conversion
pull/22455/head
k-brahma-claude 10 months ago
parent 1bdbbb7140
commit f25af8430e

@ -89,15 +89,15 @@ class LoginApi(Resource):
return {"result": "fail", "data": token, "code": "account_not_found"} return {"result": "fail", "data": token, "code": "account_not_found"}
else: else:
raise AccountNotFound() raise AccountNotFound()
# Check MFA requirement # Check MFA requirement
if MFAService.is_mfa_required(account): if MFAService.is_mfa_required(account):
if not args["mfa_code"]: if not args["mfa_code"]:
return {"result": "fail", "code": "mfa_required"} return {"result": "fail", "code": "mfa_required"}
if not MFAService.authenticate_with_mfa(account, args["mfa_code"]): if not MFAService.authenticate_with_mfa(account, args["mfa_code"]):
return {"result": "fail", "code": "mfa_token_invalid", "data": "The MFA token is invalid or expired."} return {"result": "fail", "code": "mfa_token_invalid", "data": "The MFA token is invalid or expired."}
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0: if len(tenants) == 0:

@ -15,23 +15,20 @@ class MFASetupInitApi(Resource):
def get(self): def get(self):
"""Initialize MFA setup - generate secret and QR code (GET method for compatibility).""" """Initialize MFA setup - generate secret and QR code (GET method for compatibility)."""
return self.post() return self.post()
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
"""Initialize MFA setup - generate secret and QR code.""" """Initialize MFA setup - generate secret and QR code."""
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
try: try:
mfa_status = MFAService.get_mfa_status(account) mfa_status = MFAService.get_mfa_status(account)
if mfa_status["enabled"]: if mfa_status["enabled"]:
return {"error": "MFA is already enabled"}, 400 return {"error": "MFA is already enabled"}, 400
setup_data = MFAService.generate_mfa_setup_data(account) setup_data = MFAService.generate_mfa_setup_data(account)
return { return {"secret": setup_data["secret"], "qr_code": setup_data["qr_code"]}
"secret": setup_data["secret"],
"qr_code": setup_data["qr_code"]
}
except Exception as e: except Exception as e:
return {"error": str(e)}, 500 return {"error": str(e)}, 500
@ -44,15 +41,15 @@ class MFASetupCompleteApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("totp_token", type=str, required=True, help="TOTP token is required") parser.add_argument("totp_token", type=str, required=True, help="TOTP token is required")
args = parser.parse_args() args = parser.parse_args()
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
try: try:
result = MFAService.setup_mfa(account, args["totp_token"]) result = MFAService.setup_mfa(account, args["totp_token"])
return { return {
"message": "MFA setup completed successfully", "message": "MFA setup completed successfully",
"backup_codes": result["backup_codes"], "backup_codes": result["backup_codes"],
"setup_at": result["setup_at"].isoformat() "setup_at": result["setup_at"].isoformat(),
} }
except ValueError as e: except ValueError as e:
return {"error": str(e)}, 400 return {"error": str(e)}, 400
@ -68,14 +65,14 @@ class MFADisableApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("password", type=str, required=True, help="Password is required") parser.add_argument("password", type=str, required=True, help="Password is required")
args = parser.parse_args() args = parser.parse_args()
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
try: try:
mfa_status = MFAService.get_mfa_status(account) mfa_status = MFAService.get_mfa_status(account)
if not mfa_status["enabled"]: if not mfa_status["enabled"]:
return {"error": "MFA is not enabled"}, 400 return {"error": "MFA is not enabled"}, 400
if MFAService.disable_mfa(account, args["password"]): if MFAService.disable_mfa(account, args["password"]):
return {"message": "MFA disabled successfully"} return {"message": "MFA disabled successfully"}
else: else:
@ -90,7 +87,7 @@ class MFAStatusApi(Resource):
def get(self): def get(self):
"""Get current MFA status.""" """Get current MFA status."""
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
try: try:
status = MFAService.get_mfa_status(account) status = MFAService.get_mfa_status(account)
return status return status
@ -105,20 +102,21 @@ class MFAVerifyApi(Resource):
parser.add_argument("email", type=str, required=True, help="Email is required") parser.add_argument("email", type=str, required=True, help="Email is required")
parser.add_argument("mfa_token", type=str, required=True, help="MFA token is required") parser.add_argument("mfa_token", type=str, required=True, help="MFA token is required")
args = parser.parse_args() args = parser.parse_args()
from models.engine import db from models.engine import db
account = db.session.query(Account).filter_by(email=args["email"]).first() account = db.session.query(Account).filter_by(email=args["email"]).first()
if not account: if not account:
return {"error": "Account not found"}, 404 return {"error": "Account not found"}, 404
if not MFAService.is_mfa_required(account): if not MFAService.is_mfa_required(account):
return {"error": "MFA not required for this account"}, 400 return {"error": "MFA not required for this account"}, 400
try: try:
if MFAService.authenticate_with_mfa(account, args["mfa_token"]): if MFAService.authenticate_with_mfa(account, args["mfa_token"]):
return {"message": "MFA verification successful"} return {"message": "MFA verification successful"}
else: else:
return {"error": "Invalid MFA token"}, 400 return {"error": "Invalid MFA token"}, 400
except Exception as e: except Exception as e:
return {"error": str(e)}, 500 return {"error": str(e)}, 500

@ -319,4 +319,6 @@ class AccountMFASettings(Base):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
# Relationship # Relationship
account = db.relationship("Account", backref=db.backref("mfa_settings", uselist=False, cascade="all, delete-orphan")) account = db.relationship(
"Account", backref=db.backref("mfa_settings", uselist=False, cascade="all, delete-orphan")
)

@ -30,11 +30,8 @@ class MFAService:
def generate_qr_code(account: Account, secret: str) -> str: def generate_qr_code(account: Account, secret: str) -> str:
"""Generate QR code for TOTP setup.""" """Generate QR code for TOTP setup."""
totp = pyotp.TOTP(secret) totp = pyotp.TOTP(secret)
provisioning_uri = totp.provisioning_uri( provisioning_uri = totp.provisioning_uri(name=account.email, issuer_name="Dify")
name=account.email,
issuer_name="Dify"
)
# Generate QR code # Generate QR code
qr = qrcode.QRCode( qr = qrcode.QRCode(
version=1, version=1,
@ -44,15 +41,15 @@ class MFAService:
) )
qr.add_data(provisioning_uri) qr.add_data(provisioning_uri)
qr.make(fit=True) qr.make(fit=True)
# Create image # Create image
img = qr.make_image(fill_color="black", back_color="white") img = qr.make_image(fill_color="black", back_color="white")
# Convert to base64 # Convert to base64
buffer = io.BytesIO() buffer = io.BytesIO()
img.save(buffer) img.save(buffer)
img_str = base64.b64encode(buffer.getvalue()).decode() img_str = base64.b64encode(buffer.getvalue()).decode()
return f"data:image/png;base64,{img_str}" return f"data:image/png;base64,{img_str}"
@staticmethod @staticmethod
@ -82,7 +79,7 @@ class MFAService:
"""Verify and consume backup code.""" """Verify and consume backup code."""
if not mfa_settings.backup_codes: if not mfa_settings.backup_codes:
return False return False
try: try:
backup_codes = json.loads(mfa_settings.backup_codes) backup_codes = json.loads(mfa_settings.backup_codes)
if code.upper() in backup_codes: if code.upper() in backup_codes:
@ -93,58 +90,55 @@ class MFAService:
return True return True
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
return False return False
@staticmethod @staticmethod
def setup_mfa(account: Account, totp_token: str) -> dict: def setup_mfa(account: Account, totp_token: str) -> dict:
"""Setup MFA for account with TOTP verification.""" """Setup MFA for account with TOTP verification."""
mfa_settings = MFAService.get_or_create_mfa_settings(account) mfa_settings = MFAService.get_or_create_mfa_settings(account)
if mfa_settings.enabled: if mfa_settings.enabled:
raise ValueError("MFA is already enabled for this account") raise ValueError("MFA is already enabled for this account")
if not mfa_settings.secret: if not mfa_settings.secret:
raise ValueError("MFA secret not generated") raise ValueError("MFA secret not generated")
# Verify TOTP token # Verify TOTP token
if not MFAService.verify_totp(mfa_settings.secret, totp_token): if not MFAService.verify_totp(mfa_settings.secret, totp_token):
raise ValueError("Invalid TOTP token") raise ValueError("Invalid TOTP token")
# Generate backup codes # Generate backup codes
backup_codes = MFAService.generate_backup_codes() backup_codes = MFAService.generate_backup_codes()
# Enable MFA # Enable MFA
mfa_settings.enabled = True mfa_settings.enabled = True
mfa_settings.backup_codes = json.dumps(backup_codes) mfa_settings.backup_codes = json.dumps(backup_codes)
mfa_settings.setup_at = datetime.now(UTC) mfa_settings.setup_at = datetime.now(UTC)
db.session.commit() db.session.commit()
return { return {"backup_codes": backup_codes, "setup_at": mfa_settings.setup_at}
"backup_codes": backup_codes,
"setup_at": mfa_settings.setup_at
}
@staticmethod @staticmethod
def disable_mfa(account: Account, password: str) -> bool: def disable_mfa(account: Account, password: str) -> bool:
"""Disable MFA for account after password verification.""" """Disable MFA for account after password verification."""
from libs.password import compare_password from libs.password import compare_password
# Verify password # Verify password
if account.password is None or not compare_password(password, account.password, account.password_salt): if account.password is None or not compare_password(password, account.password, account.password_salt):
return False return False
mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first()
if not mfa_settings: if not mfa_settings:
return True # Already disabled return True # Already disabled
# Disable MFA # Disable MFA
mfa_settings.enabled = False mfa_settings.enabled = False
mfa_settings.secret = None mfa_settings.secret = None
mfa_settings.backup_codes = None mfa_settings.backup_codes = None
mfa_settings.setup_at = None mfa_settings.setup_at = None
db.session.commit() db.session.commit()
return True return True
@ -152,22 +146,19 @@ class MFAService:
def generate_mfa_setup_data(account: Account) -> dict: def generate_mfa_setup_data(account: Account) -> dict:
"""Generate MFA setup data including secret and QR code.""" """Generate MFA setup data including secret and QR code."""
mfa_settings = MFAService.get_or_create_mfa_settings(account) mfa_settings = MFAService.get_or_create_mfa_settings(account)
if mfa_settings.enabled: if mfa_settings.enabled:
raise ValueError("MFA is already enabled for this account") raise ValueError("MFA is already enabled for this account")
# Generate new secret # Generate new secret
secret = MFAService.generate_secret() secret = MFAService.generate_secret()
mfa_settings.secret = secret mfa_settings.secret = secret
db.session.commit() db.session.commit()
# Generate QR code # Generate QR code
qr_code = MFAService.generate_qr_code(account, secret) qr_code = MFAService.generate_qr_code(account, secret)
return { return {"secret": secret, "qr_code": qr_code}
"secret": secret,
"qr_code": qr_code
}
@staticmethod @staticmethod
def is_mfa_required(account: Account) -> bool: def is_mfa_required(account: Account) -> bool:
@ -180,25 +171,25 @@ class MFAService:
"""Authenticate user with MFA token (TOTP or backup code).""" """Authenticate user with MFA token (TOTP or backup code)."""
print(f"[MFA DEBUG] authenticate_with_mfa called with token: {token}") print(f"[MFA DEBUG] authenticate_with_mfa called with token: {token}")
mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first()
if not mfa_settings or not mfa_settings.enabled: if not mfa_settings or not mfa_settings.enabled:
print("[MFA DEBUG] MFA not enabled, returning True") print("[MFA DEBUG] MFA not enabled, returning True")
return True return True
print(f"[MFA DEBUG] MFA enabled, secret: {mfa_settings.secret[:10]}...") print(f"[MFA DEBUG] MFA enabled, secret: {mfa_settings.secret[:10]}...")
# Try TOTP first # Try TOTP first
print("[MFA DEBUG] Trying TOTP verification") print("[MFA DEBUG] Trying TOTP verification")
if MFAService.verify_totp(mfa_settings.secret, token): if MFAService.verify_totp(mfa_settings.secret, token):
print("[MFA DEBUG] TOTP verification successful") print("[MFA DEBUG] TOTP verification successful")
return True return True
# Try backup code # Try backup code
print("[MFA DEBUG] Trying backup code verification") print("[MFA DEBUG] Trying backup code verification")
if MFAService.verify_backup_code(mfa_settings, token): if MFAService.verify_backup_code(mfa_settings, token):
print("[MFA DEBUG] Backup code verification successful") print("[MFA DEBUG] Backup code verification successful")
return True return True
print("[MFA DEBUG] All verifications failed") print("[MFA DEBUG] All verifications failed")
return False return False
@ -206,16 +197,12 @@ class MFAService:
def get_mfa_status(account: Account) -> dict: def get_mfa_status(account: Account) -> dict:
"""Get MFA status for account.""" """Get MFA status for account."""
mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first() mfa_settings = db.session.query(AccountMFASettings).filter_by(account_id=account.id).first()
if not mfa_settings: if not mfa_settings:
return { return {"enabled": False, "setup_at": None, "has_backup_codes": False}
"enabled": False,
"setup_at": None,
"has_backup_codes": False
}
return { return {
"enabled": mfa_settings.enabled, "enabled": mfa_settings.enabled,
"setup_at": mfa_settings.setup_at.isoformat() if mfa_settings.setup_at else None, "setup_at": mfa_settings.setup_at.isoformat() if mfa_settings.setup_at else None,
"has_backup_codes": mfa_settings.backup_codes is not None "has_backup_codes": mfa_settings.backup_codes is not None,
} }

@ -57,20 +57,21 @@ def setup_account(request) -> Generator[Account, None, None]:
if account: if account:
yield account yield account
return return
rand_suffix = random.randint(int(1e6), int(1e7)) # noqa rand_suffix = random.randint(int(1e6), int(1e7)) # noqa
name = f"test-user-{rand_suffix}" name = f"test-user-{rand_suffix}"
email = f"{name}@example.com" email = f"{name}@example.com"
# Clean up any existing setup first # Clean up any existing setup first
from models.account import AccountMFASettings from models.account import AccountMFASettings
db.session.query(AccountMFASettings).delete() db.session.query(AccountMFASettings).delete()
db.session.query(DifySetup).delete() db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete() db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete() db.session.query(Account).delete()
db.session.query(Tenant).delete() db.session.query(Tenant).delete()
db.session.commit() db.session.commit()
RegisterService.setup( RegisterService.setup(
email=email, email=email,
name=name, name=name,
@ -87,6 +88,7 @@ def setup_account(request) -> Generator[Account, None, None]:
with _CACHED_APP.test_request_context(): with _CACHED_APP.test_request_context():
# Clean up MFA settings first to avoid foreign key violations # Clean up MFA settings first to avoid foreign key violations
from models.account import AccountMFASettings from models.account import AccountMFASettings
db.session.query(AccountMFASettings).delete() db.session.query(AccountMFASettings).delete()
db.session.query(DifySetup).delete() db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete() db.session.query(TenantAccountJoin).delete()

@ -6,22 +6,31 @@ from flask import Flask
class TestLoginMFAIntegration: class TestLoginMFAIntegration:
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch('controllers.console.auth.login.FeatureService.get_system_features') @patch("controllers.console.auth.login.dify_config")
@patch('controllers.console.auth.login.dify_config') @patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
@patch('controllers.console.auth.login.BillingService.is_email_in_freeze') @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.authenticate")
@patch('controllers.console.auth.login.AccountService.authenticate') @patch("controllers.console.auth.login.MFAService.is_mfa_required")
@patch('controllers.console.auth.login.MFAService.is_mfa_required') @patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch('controllers.console.auth.login.TenantService.get_join_tenants') @patch("controllers.console.auth.login.AccountService.login")
@patch('controllers.console.auth.login.AccountService.login') @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
@patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') @patch("controllers.console.auth.login.extract_remote_ip")
@patch('controllers.console.auth.login.extract_remote_ip') def test_login_without_mfa_success(
def test_login_without_mfa_success(self, mock_extract_ip, mock_reset_limit, self,
mock_login_service, mock_get_tenants, mock_is_mfa_required, mock_extract_ip,
mock_authenticate, mock_rate_limit, mock_freeze_check, mock_reset_limit,
mock_dify_config, mock_system_features, mock_login_service,
test_client, setup_account): mock_get_tenants,
mock_is_mfa_required,
mock_authenticate,
mock_rate_limit,
mock_freeze_check,
mock_dify_config,
mock_system_features,
test_client,
setup_account,
):
"""Test successful login without MFA enabled.""" """Test successful login without MFA enabled."""
# Setup mocks # Setup mocks
mock_dify_config.BILLING_ENABLED = False mock_dify_config.BILLING_ENABLED = False
@ -35,19 +44,20 @@ class TestLoginMFAIntegration:
token_pair_mock = Mock() token_pair_mock = Mock()
token_pair_mock.model_dump.return_value = { token_pair_mock.model_dump.return_value = {
"access_token": "test_access_token", "access_token": "test_access_token",
"refresh_token": "test_refresh_token" "refresh_token": "test_refresh_token",
} }
mock_login_service.return_value = token_pair_mock mock_login_service.return_value = token_pair_mock
with patch('controllers.console.auth.login.setup_required') as mock_setup, \ with (
patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: patch("controllers.console.auth.login.setup_required") as mock_setup,
patch("controllers.console.auth.login.email_password_login_enabled") as mock_email_enabled,
):
mock_setup.return_value = lambda f: f mock_setup.return_value = lambda f: f
mock_email_enabled.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f
response = test_client.post('/console/api/login', json={ response = test_client.post(
"email": setup_account.email, "/console/api/login", json={"email": setup_account.email, "password": "TestPassword123"}
"password": "TestPassword123" )
})
# Debug output # Debug output
if response.status_code != 200: if response.status_code != 200:
@ -59,15 +69,23 @@ class TestLoginMFAIntegration:
assert data["result"] == "success" assert data["result"] == "success"
assert "access_token" in data["data"] assert "access_token" in data["data"]
@patch('controllers.console.auth.login.FeatureService.get_system_features') @patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch('controllers.console.auth.login.dify_config') @patch("controllers.console.auth.login.dify_config")
@patch('controllers.console.auth.login.BillingService.is_email_in_freeze') @patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
@patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch('controllers.console.auth.login.AccountService.authenticate') @patch("controllers.console.auth.login.AccountService.authenticate")
@patch('controllers.console.auth.login.MFAService.is_mfa_required') @patch("controllers.console.auth.login.MFAService.is_mfa_required")
def test_login_with_mfa_required_no_token(self, mock_is_mfa_required, mock_authenticate, def test_login_with_mfa_required_no_token(
mock_rate_limit, mock_freeze_check, mock_dify_config, self,
mock_system_features, test_client, setup_account): mock_is_mfa_required,
mock_authenticate,
mock_rate_limit,
mock_freeze_check,
mock_dify_config,
mock_system_features,
test_client,
setup_account,
):
"""Test login returns mfa_required when MFA is enabled but no token provided.""" """Test login returns mfa_required when MFA is enabled but no token provided."""
# Setup mocks # Setup mocks
mock_dify_config.BILLING_ENABLED = False mock_dify_config.BILLING_ENABLED = False
@ -76,31 +94,41 @@ class TestLoginMFAIntegration:
mock_authenticate.return_value = setup_account mock_authenticate.return_value = setup_account
mock_is_mfa_required.return_value = True mock_is_mfa_required.return_value = True
with patch('controllers.console.auth.login.setup_required') as mock_setup, \ with (
patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: patch("controllers.console.auth.login.setup_required") as mock_setup,
patch("controllers.console.auth.login.email_password_login_enabled") as mock_email_enabled,
):
mock_setup.return_value = lambda f: f mock_setup.return_value = lambda f: f
mock_email_enabled.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f
response = test_client.post('/console/api/login', json={ response = test_client.post(
"email": "test@example.com", "/console/api/login", json={"email": "test@example.com", "password": "TestPassword123"}
"password": "TestPassword123" )
})
assert response.status_code == 200 assert response.status_code == 200
data = json.loads(response.data) data = json.loads(response.data)
assert data["result"] == "fail" assert data["result"] == "fail"
assert data["code"] == "mfa_required" assert data["code"] == "mfa_required"
@patch('controllers.console.auth.login.FeatureService.get_system_features') @patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch('controllers.console.auth.login.dify_config') @patch("controllers.console.auth.login.dify_config")
@patch('controllers.console.auth.login.BillingService.is_email_in_freeze') @patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
@patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch('controllers.console.auth.login.AccountService.authenticate') @patch("controllers.console.auth.login.AccountService.authenticate")
@patch('controllers.console.auth.login.MFAService.is_mfa_required') @patch("controllers.console.auth.login.MFAService.is_mfa_required")
@patch('controllers.console.auth.login.MFAService.authenticate_with_mfa') @patch("controllers.console.auth.login.MFAService.authenticate_with_mfa")
def test_login_with_mfa_invalid_token(self, mock_auth_mfa, mock_is_mfa_required, mock_authenticate, def test_login_with_mfa_invalid_token(
mock_rate_limit, mock_freeze_check, mock_dify_config, self,
mock_system_features, test_client, setup_account): mock_auth_mfa,
mock_is_mfa_required,
mock_authenticate,
mock_rate_limit,
mock_freeze_check,
mock_dify_config,
mock_system_features,
test_client,
setup_account,
):
"""Test login fails with invalid MFA token.""" """Test login fails with invalid MFA token."""
# Setup mocks # Setup mocks
mock_dify_config.BILLING_ENABLED = False mock_dify_config.BILLING_ENABLED = False
@ -110,16 +138,17 @@ class TestLoginMFAIntegration:
mock_is_mfa_required.return_value = True mock_is_mfa_required.return_value = True
mock_auth_mfa.return_value = False # Invalid token mock_auth_mfa.return_value = False # Invalid token
with patch('controllers.console.auth.login.setup_required') as mock_setup, \ with (
patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: patch("controllers.console.auth.login.setup_required") as mock_setup,
patch("controllers.console.auth.login.email_password_login_enabled") as mock_email_enabled,
):
mock_setup.return_value = lambda f: f mock_setup.return_value = lambda f: f
mock_email_enabled.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f
response = test_client.post('/console/api/login', json={ response = test_client.post(
"email": "test@example.com", "/console/api/login",
"password": "TestPassword123", json={"email": "test@example.com", "password": "TestPassword123", "mfa_code": "invalid_token"},
"mfa_code": "invalid_token" )
})
assert response.status_code == 200 assert response.status_code == 200
data = json.loads(response.data) data = json.loads(response.data)
@ -127,22 +156,33 @@ class TestLoginMFAIntegration:
assert data["code"] == "mfa_token_invalid" assert data["code"] == "mfa_token_invalid"
assert data["data"] == "The MFA token is invalid or expired." assert data["data"] == "The MFA token is invalid or expired."
@patch('controllers.console.auth.login.FeatureService.get_system_features') @patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch('controllers.console.auth.login.dify_config') @patch("controllers.console.auth.login.dify_config")
@patch('controllers.console.auth.login.BillingService.is_email_in_freeze') @patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
@patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch('controllers.console.auth.login.AccountService.authenticate') @patch("controllers.console.auth.login.AccountService.authenticate")
@patch('controllers.console.auth.login.MFAService.is_mfa_required') @patch("controllers.console.auth.login.MFAService.is_mfa_required")
@patch('controllers.console.auth.login.MFAService.authenticate_with_mfa') @patch("controllers.console.auth.login.MFAService.authenticate_with_mfa")
@patch('controllers.console.auth.login.TenantService.get_join_tenants') @patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch('controllers.console.auth.login.AccountService.login') @patch("controllers.console.auth.login.AccountService.login")
@patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
@patch('controllers.console.auth.login.extract_remote_ip') @patch("controllers.console.auth.login.extract_remote_ip")
def test_login_with_mfa_valid_token_success(self, mock_extract_ip, mock_reset_limit, def test_login_with_mfa_valid_token_success(
mock_login_service, mock_get_tenants, mock_auth_mfa, self,
mock_is_mfa_required, mock_authenticate, mock_extract_ip,
mock_rate_limit, mock_freeze_check, mock_dify_config, mock_reset_limit,
mock_system_features, test_client, setup_account): mock_login_service,
mock_get_tenants,
mock_auth_mfa,
mock_is_mfa_required,
mock_authenticate,
mock_rate_limit,
mock_freeze_check,
mock_dify_config,
mock_system_features,
test_client,
setup_account,
):
"""Test successful login with valid MFA token.""" """Test successful login with valid MFA token."""
# Setup mocks # Setup mocks
mock_dify_config.BILLING_ENABLED = False mock_dify_config.BILLING_ENABLED = False
@ -157,20 +197,21 @@ class TestLoginMFAIntegration:
token_pair_mock = Mock() token_pair_mock = Mock()
token_pair_mock.model_dump.return_value = { token_pair_mock.model_dump.return_value = {
"access_token": "test_access_token", "access_token": "test_access_token",
"refresh_token": "test_refresh_token" "refresh_token": "test_refresh_token",
} }
mock_login_service.return_value = token_pair_mock mock_login_service.return_value = token_pair_mock
with patch('controllers.console.auth.login.setup_required') as mock_setup, \ with (
patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: patch("controllers.console.auth.login.setup_required") as mock_setup,
patch("controllers.console.auth.login.email_password_login_enabled") as mock_email_enabled,
):
mock_setup.return_value = lambda f: f mock_setup.return_value = lambda f: f
mock_email_enabled.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f
response = test_client.post('/console/api/login', json={ response = test_client.post(
"email": "test@example.com", "/console/api/login",
"password": "TestPassword123", json={"email": "test@example.com", "password": "TestPassword123", "mfa_code": "123456"},
"mfa_code": "123456" )
})
assert response.status_code == 200 assert response.status_code == 200
data = json.loads(response.data) data = json.loads(response.data)
@ -180,22 +221,33 @@ class TestLoginMFAIntegration:
# Verify MFA authentication was called # Verify MFA authentication was called
mock_auth_mfa.assert_called_once_with(setup_account, "123456") mock_auth_mfa.assert_called_once_with(setup_account, "123456")
@patch('controllers.console.auth.login.FeatureService.get_system_features') @patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch('controllers.console.auth.login.dify_config') @patch("controllers.console.auth.login.dify_config")
@patch('controllers.console.auth.login.BillingService.is_email_in_freeze') @patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
@patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch('controllers.console.auth.login.AccountService.authenticate') @patch("controllers.console.auth.login.AccountService.authenticate")
@patch('controllers.console.auth.login.MFAService.is_mfa_required') @patch("controllers.console.auth.login.MFAService.is_mfa_required")
@patch('controllers.console.auth.login.MFAService.authenticate_with_mfa') @patch("controllers.console.auth.login.MFAService.authenticate_with_mfa")
@patch('controllers.console.auth.login.TenantService.get_join_tenants') @patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch('controllers.console.auth.login.AccountService.login') @patch("controllers.console.auth.login.AccountService.login")
@patch('controllers.console.auth.login.AccountService.reset_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
@patch('controllers.console.auth.login.extract_remote_ip') @patch("controllers.console.auth.login.extract_remote_ip")
def test_login_with_mfa_backup_code_success(self, mock_extract_ip, mock_reset_limit, def test_login_with_mfa_backup_code_success(
mock_login_service, mock_get_tenants, mock_auth_mfa, self,
mock_is_mfa_required, mock_authenticate, mock_extract_ip,
mock_rate_limit, mock_freeze_check, mock_dify_config, mock_reset_limit,
mock_system_features, test_client, setup_account): mock_login_service,
mock_get_tenants,
mock_auth_mfa,
mock_is_mfa_required,
mock_authenticate,
mock_rate_limit,
mock_freeze_check,
mock_dify_config,
mock_system_features,
test_client,
setup_account,
):
"""Test successful login with valid backup code.""" """Test successful login with valid backup code."""
# Setup mocks # Setup mocks
mock_dify_config.BILLING_ENABLED = False mock_dify_config.BILLING_ENABLED = False
@ -210,20 +262,25 @@ class TestLoginMFAIntegration:
token_pair_mock = Mock() token_pair_mock = Mock()
token_pair_mock.model_dump.return_value = { token_pair_mock.model_dump.return_value = {
"access_token": "test_access_token", "access_token": "test_access_token",
"refresh_token": "test_refresh_token" "refresh_token": "test_refresh_token",
} }
mock_login_service.return_value = token_pair_mock mock_login_service.return_value = token_pair_mock
with patch('controllers.console.auth.login.setup_required') as mock_setup, \ with (
patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled: patch("controllers.console.auth.login.setup_required") as mock_setup,
patch("controllers.console.auth.login.email_password_login_enabled") as mock_email_enabled,
):
mock_setup.return_value = lambda f: f mock_setup.return_value = lambda f: f
mock_email_enabled.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f
response = test_client.post('/console/api/login', json={ response = test_client.post(
"email": "test@example.com", "/console/api/login",
"password": "TestPassword123", json={
"mfa_code": "BACKUP123" # Backup code format "email": "test@example.com",
}) "password": "TestPassword123",
"mfa_code": "BACKUP123", # Backup code format
},
)
assert response.status_code == 200 assert response.status_code == 200
data = json.loads(response.data) data = json.loads(response.data)
@ -233,15 +290,22 @@ class TestLoginMFAIntegration:
# Verify MFA authentication was called with backup code # Verify MFA authentication was called with backup code
mock_auth_mfa.assert_called_once_with(setup_account, "BACKUP123") mock_auth_mfa.assert_called_once_with(setup_account, "BACKUP123")
@patch('controllers.console.auth.login.FeatureService.get_system_features') @patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch('controllers.console.auth.login.dify_config') @patch("controllers.console.auth.login.dify_config")
@patch('controllers.console.auth.login.BillingService.is_email_in_freeze') @patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
@patch('controllers.console.auth.login.AccountService.is_login_error_rate_limit') @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch('controllers.console.auth.login.AccountService.authenticate') @patch("controllers.console.auth.login.AccountService.authenticate")
@patch('controllers.console.auth.login.MFAService.is_mfa_required') @patch("controllers.console.auth.login.MFAService.is_mfa_required")
def test_login_mfa_flow_order(self, mock_is_mfa_required, mock_authenticate, def test_login_mfa_flow_order(
mock_rate_limit, mock_freeze_check, mock_dify_config, self,
mock_system_features, test_client): mock_is_mfa_required,
mock_authenticate,
mock_rate_limit,
mock_freeze_check,
mock_dify_config,
mock_system_features,
test_client,
):
"""Test that MFA check happens after password authentication.""" """Test that MFA check happens after password authentication."""
# Setup mocks - password auth fails # Setup mocks - password auth fails
mock_dify_config.BILLING_ENABLED = False mock_dify_config.BILLING_ENABLED = False
@ -250,19 +314,21 @@ class TestLoginMFAIntegration:
# Mock password authentication failure # Mock password authentication failure
from services.errors.account import AccountPasswordError from services.errors.account import AccountPasswordError
mock_authenticate.side_effect = AccountPasswordError() mock_authenticate.side_effect = AccountPasswordError()
with patch('controllers.console.auth.login.setup_required') as mock_setup, \ with (
patch('controllers.console.auth.login.email_password_login_enabled') as mock_email_enabled, \ patch("controllers.console.auth.login.setup_required") as mock_setup,
patch('controllers.console.auth.login.AccountService.add_login_error_rate_limit') as mock_add_limit: patch("controllers.console.auth.login.email_password_login_enabled") as mock_email_enabled,
patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") as mock_add_limit,
):
mock_setup.return_value = lambda f: f mock_setup.return_value = lambda f: f
mock_email_enabled.return_value = lambda f: f mock_email_enabled.return_value = lambda f: f
response = test_client.post('/console/api/login', json={ response = test_client.post(
"email": "test@example.com", "/console/api/login",
"password": "WrongPassword123", json={"email": "test@example.com", "password": "WrongPassword123", "mfa_code": "123456"},
"mfa_code": "123456" )
})
# Password error should trigger EmailOrPasswordMismatchError # Password error should trigger EmailOrPasswordMismatchError
assert response.status_code == 400 assert response.status_code == 400
@ -276,14 +342,14 @@ class TestMFAEndToEndFlow:
def setup_method(self): def setup_method(self):
self.app = Flask(__name__) self.app = Flask(__name__)
self.app.config['TESTING'] = True self.app.config["TESTING"] = True
self.client = self.app.test_client() self.client = self.app.test_client()
@patch('services.mfa_service.MFAService.generate_secret') @patch("services.mfa_service.MFAService.generate_secret")
@patch('services.mfa_service.MFAService.generate_qr_code') @patch("services.mfa_service.MFAService.generate_qr_code")
@patch('services.mfa_service.MFAService.verify_totp') @patch("services.mfa_service.MFAService.verify_totp")
@patch('services.mfa_service.MFAService.generate_backup_codes') @patch("services.mfa_service.MFAService.generate_backup_codes")
@patch('services.mfa_service.db.session') @patch("services.mfa_service.db.session")
def test_complete_mfa_setup_flow(self, mock_session, mock_gen_codes, mock_verify, mock_gen_qr, mock_gen_secret): def test_complete_mfa_setup_flow(self, mock_session, mock_gen_codes, mock_verify, mock_gen_qr, mock_gen_secret):
"""Test complete MFA setup flow from init to completion.""" """Test complete MFA setup flow from init to completion."""
from models.account import Account from models.account import Account
@ -301,7 +367,7 @@ class TestMFAEndToEndFlow:
mock_gen_codes.return_value = ["CODE1", "CODE2", "CODE3"] mock_gen_codes.return_value = ["CODE1", "CODE2", "CODE3"]
# Step 1: Initialize MFA setup # Step 1: Initialize MFA setup
with patch('services.mfa_service.MFAService.get_or_create_mfa_settings') as mock_get_settings: with patch("services.mfa_service.MFAService.get_or_create_mfa_settings") as mock_get_settings:
mfa_settings = Mock() mfa_settings = Mock()
mfa_settings.enabled = False mfa_settings.enabled = False
mfa_settings.secret = None mfa_settings.secret = None
@ -314,7 +380,7 @@ class TestMFAEndToEndFlow:
assert mfa_settings.secret == "TESTSECRET123" assert mfa_settings.secret == "TESTSECRET123"
# Step 2: Complete MFA setup # Step 2: Complete MFA setup
with patch('services.mfa_service.MFAService.get_or_create_mfa_settings') as mock_get_settings: with patch("services.mfa_service.MFAService.get_or_create_mfa_settings") as mock_get_settings:
mfa_settings.secret = "TESTSECRET123" mfa_settings.secret = "TESTSECRET123"
mock_get_settings.return_value = mfa_settings mock_get_settings.return_value = mfa_settings

@ -9,159 +9,133 @@ from services.mfa_service import MFAService
class TestMFAEndpoints: class TestMFAEndpoints:
"""Test MFA endpoints using integration test approach.""" """Test MFA endpoints using integration test approach."""
@pytest.fixture @pytest.fixture
def auth_header(self, setup_account): def auth_header(self, setup_account):
"""Get authentication header with JWT token.""" """Get authentication header with JWT token."""
token = AccountService.get_account_jwt_token(setup_account) token = AccountService.get_account_jwt_token(setup_account)
return {"Authorization": f"Bearer {token}"} return {"Authorization": f"Bearer {token}"}
def test_mfa_status_success(self, test_client, setup_account, auth_header): def test_mfa_status_success(self, test_client, setup_account, auth_header):
"""Test successful MFA status check.""" """Test successful MFA status check."""
with patch.object(MFAService, 'get_mfa_status') as mock_status: with patch.object(MFAService, "get_mfa_status") as mock_status:
mock_status.return_value = {"enabled": False, "setup_at": None} mock_status.return_value = {"enabled": False, "setup_at": None}
response = test_client.get( response = test_client.get("/console/api/account/mfa/status", headers=auth_header)
'/console/api/account/mfa/status',
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert data["enabled"] is False assert data["enabled"] is False
assert data["setup_at"] is None assert data["setup_at"] is None
mock_status.assert_called_once_with(setup_account) mock_status.assert_called_once_with(setup_account)
def test_mfa_setup_init_success(self, test_client, setup_account, auth_header): def test_mfa_setup_init_success(self, test_client, setup_account, auth_header):
"""Test successful MFA setup initialization.""" """Test successful MFA setup initialization."""
with patch.object(MFAService, 'get_mfa_status') as mock_status: with patch.object(MFAService, "get_mfa_status") as mock_status:
with patch.object(MFAService, 'generate_mfa_setup_data') as mock_generate: with patch.object(MFAService, "generate_mfa_setup_data") as mock_generate:
mock_status.return_value = {"enabled": False} mock_status.return_value = {"enabled": False}
mock_generate.return_value = { mock_generate.return_value = {"secret": "TEST_SECRET", "qr_code": "data:image/png;base64,test"}
"secret": "TEST_SECRET",
"qr_code": "data:image/png;base64,test" response = test_client.post("/console/api/account/mfa/setup", headers=auth_header)
}
response = test_client.post(
'/console/api/account/mfa/setup',
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert data["secret"] == "TEST_SECRET" assert data["secret"] == "TEST_SECRET"
assert data["qr_code"] == "data:image/png;base64,test" assert data["qr_code"] == "data:image/png;base64,test"
mock_generate.assert_called_once_with(setup_account) mock_generate.assert_called_once_with(setup_account)
def test_mfa_setup_init_already_enabled(self, test_client, setup_account, auth_header): def test_mfa_setup_init_already_enabled(self, test_client, setup_account, auth_header):
"""Test MFA setup initialization when already enabled.""" """Test MFA setup initialization when already enabled."""
with patch.object(MFAService, 'get_mfa_status') as mock_status: with patch.object(MFAService, "get_mfa_status") as mock_status:
mock_status.return_value = {"enabled": True, "setup_at": "2024-01-01T00:00:00"} mock_status.return_value = {"enabled": True, "setup_at": "2024-01-01T00:00:00"}
response = test_client.post( response = test_client.post("/console/api/account/mfa/setup", headers=auth_header)
'/console/api/account/mfa/setup',
headers=auth_header
)
assert response.status_code == 400 assert response.status_code == 400
data = response.json data = response.json
assert data["error"] == "MFA is already enabled" assert data["error"] == "MFA is already enabled"
def test_mfa_setup_complete_success(self, test_client, setup_account, auth_header): def test_mfa_setup_complete_success(self, test_client, setup_account, auth_header):
"""Test successful MFA setup completion.""" """Test successful MFA setup completion."""
with patch.object(MFAService, 'setup_mfa') as mock_setup: with patch.object(MFAService, "setup_mfa") as mock_setup:
mock_setup.return_value = { mock_setup.return_value = {
"backup_codes": ["CODE1", "CODE2", "CODE3", "CODE4", "CODE5", "CODE6", "CODE7", "CODE8"], "backup_codes": ["CODE1", "CODE2", "CODE3", "CODE4", "CODE5", "CODE6", "CODE7", "CODE8"],
"setup_at": datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) "setup_at": datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC),
} }
response = test_client.post( response = test_client.post(
'/console/api/account/mfa/setup/complete', "/console/api/account/mfa/setup/complete", headers=auth_header, json={"totp_token": "123456"}
headers=auth_header,
json={"totp_token": "123456"}
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert data["message"] == "MFA setup completed successfully" assert data["message"] == "MFA setup completed successfully"
assert len(data["backup_codes"]) == 8 assert len(data["backup_codes"]) == 8
assert data["setup_at"] == "2024-01-01T00:00:00+00:00" assert data["setup_at"] == "2024-01-01T00:00:00+00:00"
mock_setup.assert_called_once_with(setup_account, "123456") mock_setup.assert_called_once_with(setup_account, "123456")
def test_mfa_setup_complete_missing_token(self, test_client, setup_account, auth_header): def test_mfa_setup_complete_missing_token(self, test_client, setup_account, auth_header):
"""Test MFA setup completion with missing token.""" """Test MFA setup completion with missing token."""
response = test_client.post( response = test_client.post("/console/api/account/mfa/setup/complete", headers=auth_header, json={})
'/console/api/account/mfa/setup/complete',
headers=auth_header,
json={}
)
assert response.status_code == 400 assert response.status_code == 400
data = response.json data = response.json
assert "message" in data assert "message" in data
assert "TOTP token is required" in data["message"] assert "TOTP token is required" in data["message"]
def test_mfa_setup_complete_invalid_token(self, test_client, setup_account, auth_header): def test_mfa_setup_complete_invalid_token(self, test_client, setup_account, auth_header):
"""Test MFA setup completion with invalid token.""" """Test MFA setup completion with invalid token."""
with patch.object(MFAService, 'setup_mfa') as mock_setup: with patch.object(MFAService, "setup_mfa") as mock_setup:
mock_setup.side_effect = ValueError("Invalid TOTP token") mock_setup.side_effect = ValueError("Invalid TOTP token")
response = test_client.post( response = test_client.post(
'/console/api/account/mfa/setup/complete', "/console/api/account/mfa/setup/complete", headers=auth_header, json={"totp_token": "999999"}
headers=auth_header,
json={"totp_token": "999999"}
) )
assert response.status_code == 400 assert response.status_code == 400
data = response.json data = response.json
assert "Invalid TOTP token" in data["error"] assert "Invalid TOTP token" in data["error"]
def test_mfa_disable_success(self, test_client, setup_account, auth_header): def test_mfa_disable_success(self, test_client, setup_account, auth_header):
"""Test successful MFA disable.""" """Test successful MFA disable."""
with patch.object(MFAService, 'get_mfa_status') as mock_status: with patch.object(MFAService, "get_mfa_status") as mock_status:
with patch.object(MFAService, 'disable_mfa') as mock_disable: with patch.object(MFAService, "disable_mfa") as mock_disable:
mock_status.return_value = {"enabled": True} mock_status.return_value = {"enabled": True}
mock_disable.return_value = True mock_disable.return_value = True
response = test_client.post( response = test_client.post(
'/console/api/account/mfa/disable', "/console/api/account/mfa/disable", headers=auth_header, json={"password": "test_password"}
headers=auth_header,
json={"password": "test_password"}
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert data["message"] == "MFA disabled successfully" assert data["message"] == "MFA disabled successfully"
mock_disable.assert_called_once_with(setup_account, "test_password") mock_disable.assert_called_once_with(setup_account, "test_password")
def test_mfa_disable_wrong_password(self, test_client, setup_account, auth_header): def test_mfa_disable_wrong_password(self, test_client, setup_account, auth_header):
"""Test MFA disable with wrong password.""" """Test MFA disable with wrong password."""
with patch.object(MFAService, 'get_mfa_status') as mock_status: with patch.object(MFAService, "get_mfa_status") as mock_status:
with patch.object(MFAService, 'disable_mfa') as mock_disable: with patch.object(MFAService, "disable_mfa") as mock_disable:
mock_status.return_value = {"enabled": True} mock_status.return_value = {"enabled": True}
mock_disable.return_value = False mock_disable.return_value = False
response = test_client.post( response = test_client.post(
'/console/api/account/mfa/disable', "/console/api/account/mfa/disable", headers=auth_header, json={"password": "wrong_password"}
headers=auth_header,
json={"password": "wrong_password"}
) )
assert response.status_code == 400 assert response.status_code == 400
data = response.json data = response.json
assert data["error"] == "Invalid password" assert data["error"] == "Invalid password"
def test_mfa_disable_not_enabled(self, test_client, setup_account, auth_header): def test_mfa_disable_not_enabled(self, test_client, setup_account, auth_header):
"""Test MFA disable when not enabled.""" """Test MFA disable when not enabled."""
with patch.object(MFAService, 'get_mfa_status') as mock_status: with patch.object(MFAService, "get_mfa_status") as mock_status:
mock_status.return_value = {"enabled": False} mock_status.return_value = {"enabled": False}
response = test_client.post( response = test_client.post(
'/console/api/account/mfa/disable', "/console/api/account/mfa/disable", headers=auth_header, json={"password": "test_password"}
headers=auth_header,
json={"password": "test_password"}
) )
assert response.status_code == 400 assert response.status_code == 400
data = response.json data = response.json
assert data["error"] == "MFA is not enabled" assert data["error"] == "MFA is not enabled"

@ -5,103 +5,81 @@ from services.mfa_service import MFAService
class TestMFASimpleIntegration: class TestMFASimpleIntegration:
"""Simple integration tests for MFA functionality.""" """Simple integration tests for MFA functionality."""
def test_mfa_setup_flow(self, test_client, setup_account, auth_header): def test_mfa_setup_flow(self, test_client, setup_account, auth_header):
"""Test MFA setup flow end-to-end.""" """Test MFA setup flow end-to-end."""
# Step 1: Check initial MFA status # Step 1: Check initial MFA status
response = test_client.get( response = test_client.get("/console/api/account/mfa/status", headers=auth_header)
"/console/api/account/mfa/status",
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert data["enabled"] is False assert data["enabled"] is False
# Step 2: Initialize MFA setup # Step 2: Initialize MFA setup
response = test_client.post( response = test_client.post("/console/api/account/mfa/setup", headers=auth_header)
"/console/api/account/mfa/setup",
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert "secret" in data assert "secret" in data
assert "qr_code" in data assert "qr_code" in data
secret = data["secret"] secret = data["secret"]
# Step 3: Complete MFA setup with mocked TOTP # Step 3: Complete MFA setup with mocked TOTP
with mock.patch.object(MFAService, 'verify_totp', return_value=True): with mock.patch.object(MFAService, "verify_totp", return_value=True):
response = test_client.post( response = test_client.post(
"/console/api/account/mfa/setup/complete", "/console/api/account/mfa/setup/complete", headers=auth_header, json={"totp_token": "123456"}
headers=auth_header,
json={"totp_token": "123456"}
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert "backup_codes" in data assert "backup_codes" in data
assert len(data["backup_codes"]) == 8 assert len(data["backup_codes"]) == 8
# Step 4: Verify MFA is now enabled # Step 4: Verify MFA is now enabled
response = test_client.get( response = test_client.get("/console/api/account/mfa/status", headers=auth_header)
"/console/api/account/mfa/status",
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert data["enabled"] is True assert data["enabled"] is True
def test_mfa_disable_flow(self, test_client, setup_account, auth_header): def test_mfa_disable_flow(self, test_client, setup_account, auth_header):
"""Test MFA disable flow.""" """Test MFA disable flow."""
# First check MFA status and disable if already enabled # First check MFA status and disable if already enabled
response = test_client.get( response = test_client.get("/console/api/account/mfa/status", headers=auth_header)
"/console/api/account/mfa/status",
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
if data["enabled"]: if data["enabled"]:
# MFA is already enabled, disable it first with mocked password verification # MFA is already enabled, disable it first with mocked password verification
with mock.patch('libs.password.compare_password', return_value=True): with mock.patch("libs.password.compare_password", return_value=True):
response = test_client.post( response = test_client.post(
"/console/api/account/mfa/disable", "/console/api/account/mfa/disable",
headers=auth_header, headers=auth_header,
json={"password": "any_password"} # Password doesn't matter, it's mocked json={"password": "any_password"}, # Password doesn't matter, it's mocked
) )
assert response.status_code == 200 assert response.status_code == 200
# Now set up MFA for the account # Now set up MFA for the account
with mock.patch.object(MFAService, 'verify_totp', return_value=True): with mock.patch.object(MFAService, "verify_totp", return_value=True):
# Initialize setup # Initialize setup
response = test_client.post( response = test_client.post("/console/api/account/mfa/setup", headers=auth_header)
"/console/api/account/mfa/setup",
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
# Complete setup # Complete setup
response = test_client.post( response = test_client.post(
"/console/api/account/mfa/setup/complete", "/console/api/account/mfa/setup/complete", headers=auth_header, json={"totp_token": "123456"}
headers=auth_header,
json={"totp_token": "123456"}
) )
assert response.status_code == 200 assert response.status_code == 200
# Now disable MFA with mocked password verification # Now disable MFA with mocked password verification
with mock.patch('libs.password.compare_password', return_value=True): with mock.patch("libs.password.compare_password", return_value=True):
response = test_client.post( response = test_client.post(
"/console/api/account/mfa/disable", "/console/api/account/mfa/disable",
headers=auth_header, headers=auth_header,
json={"password": "any_password"} # Password doesn't matter, it's mocked json={"password": "any_password"}, # Password doesn't matter, it's mocked
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert "disabled successfully" in data["message"] assert "disabled successfully" in data["message"]
# Verify MFA is disabled # Verify MFA is disabled
response = test_client.get( response = test_client.get("/console/api/account/mfa/status", headers=auth_header)
"/console/api/account/mfa/status",
headers=auth_header
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json data = response.json
assert data["enabled"] is False assert data["enabled"] is False

@ -3,6 +3,8 @@ import unittest
from datetime import datetime from datetime import datetime
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest
from models.account import Account, AccountMFASettings from models.account import Account, AccountMFASettings
from services.mfa_service import MFAService from services.mfa_service import MFAService
@ -14,7 +16,7 @@ class TestMFAService(unittest.TestCase):
self.account.email = "test@example.com" self.account.email = "test@example.com"
self.account.password = "hashed_password" self.account.password = "hashed_password"
self.account.password_salt = "salt" self.account.password_salt = "salt"
self.mfa_settings = Mock(spec=AccountMFASettings) self.mfa_settings = Mock(spec=AccountMFASettings)
self.mfa_settings.account_id = self.account.id self.mfa_settings.account_id = self.account.id
self.mfa_settings.enabled = False self.mfa_settings.enabled = False
@ -25,198 +27,198 @@ class TestMFAService(unittest.TestCase):
def test_generate_secret(self): def test_generate_secret(self):
"""Test secret generation.""" """Test secret generation."""
secret = MFAService.generate_secret() secret = MFAService.generate_secret()
self.assertIsInstance(secret, str) assert isinstance(secret, str)
self.assertEqual(len(secret), 32) # Base32 length assert len(secret) == 32 # Base32 length
def test_generate_backup_codes(self): def test_generate_backup_codes(self):
"""Test backup codes generation.""" """Test backup codes generation."""
codes = MFAService.generate_backup_codes() codes = MFAService.generate_backup_codes()
self.assertEqual(len(codes), 8) assert len(codes) == 8
for code in codes: for code in codes:
self.assertIsInstance(code, str) assert isinstance(code, str)
self.assertEqual(len(code), 8) # 4 hex bytes = 8 chars assert len(code) == 8 # 4 hex bytes = 8 chars
@patch('pyotp.TOTP') @patch("pyotp.TOTP")
def test_verify_totp_valid(self, mock_totp_class): def test_verify_totp_valid(self, mock_totp_class):
"""Test TOTP verification with valid token.""" """Test TOTP verification with valid token."""
mock_totp = Mock() mock_totp = Mock()
mock_totp.verify.return_value = True mock_totp.verify.return_value = True
mock_totp_class.return_value = mock_totp mock_totp_class.return_value = mock_totp
result = MFAService.verify_totp("test_secret", "123456") result = MFAService.verify_totp("test_secret", "123456")
self.assertTrue(result) assert result
mock_totp.verify.assert_called_once_with("123456", valid_window=1) mock_totp.verify.assert_called_once_with("123456", valid_window=1)
@patch('pyotp.TOTP') @patch("pyotp.TOTP")
def test_verify_totp_invalid(self, mock_totp_class): def test_verify_totp_invalid(self, mock_totp_class):
"""Test TOTP verification with invalid token.""" """Test TOTP verification with invalid token."""
mock_totp = Mock() mock_totp = Mock()
mock_totp.verify.return_value = False mock_totp.verify.return_value = False
mock_totp_class.return_value = mock_totp mock_totp_class.return_value = mock_totp
result = MFAService.verify_totp("test_secret", "invalid") result = MFAService.verify_totp("test_secret", "invalid")
self.assertFalse(result) assert not result
def test_verify_totp_no_secret(self): def test_verify_totp_no_secret(self):
"""Test TOTP verification with no secret.""" """Test TOTP verification with no secret."""
result = MFAService.verify_totp(None, "123456") result = MFAService.verify_totp(None, "123456")
self.assertFalse(result) assert not result
@patch('services.mfa_service.db.session') @patch("services.mfa_service.db.session")
def test_get_or_create_mfa_settings_existing(self, mock_session): def test_get_or_create_mfa_settings_existing(self, mock_session):
"""Test getting existing MFA settings.""" """Test getting existing MFA settings."""
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
result = MFAService.get_or_create_mfa_settings(self.account) result = MFAService.get_or_create_mfa_settings(self.account)
self.assertEqual(result, self.mfa_settings) assert result == self.mfa_settings
mock_session.query.assert_called_once() mock_session.query.assert_called_once()
@patch('services.mfa_service.db.session') @patch("services.mfa_service.db.session")
def test_get_or_create_mfa_settings_new(self, mock_session): def test_get_or_create_mfa_settings_new(self, mock_session):
"""Test creating new MFA settings.""" """Test creating new MFA settings."""
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.query.return_value.filter_by.return_value.first.return_value = None
result = MFAService.get_or_create_mfa_settings(self.account) result = MFAService.get_or_create_mfa_settings(self.account)
# Check that new settings were created # Check that new settings were created
self.assertIsInstance(result, AccountMFASettings) assert isinstance(result, AccountMFASettings)
self.assertEqual(result.account_id, self.account.id) assert result.account_id == self.account.id
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
@patch('services.mfa_service.db.session') @patch("services.mfa_service.db.session")
def test_verify_backup_code_valid(self, mock_session): def test_verify_backup_code_valid(self, mock_session):
"""Test backup code verification with valid code.""" """Test backup code verification with valid code."""
self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"]) self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"])
result = MFAService.verify_backup_code(self.mfa_settings, "abcd1234") # Test case insensitive result = MFAService.verify_backup_code(self.mfa_settings, "abcd1234") # Test case insensitive
self.assertTrue(result) assert result
# Check that the code was removed # Check that the code was removed
remaining_codes = json.loads(self.mfa_settings.backup_codes) remaining_codes = json.loads(self.mfa_settings.backup_codes)
self.assertNotIn("ABCD1234", remaining_codes) assert "ABCD1234" not in remaining_codes
self.assertIn("EFGH5678", remaining_codes) assert "EFGH5678" in remaining_codes
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
def test_verify_backup_code_invalid(self): def test_verify_backup_code_invalid(self):
"""Test backup code verification with invalid code.""" """Test backup code verification with invalid code."""
self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"]) self.mfa_settings.backup_codes = json.dumps(["ABCD1234", "EFGH5678"])
result = MFAService.verify_backup_code(self.mfa_settings, "INVALID") result = MFAService.verify_backup_code(self.mfa_settings, "INVALID")
self.assertFalse(result) assert not result
def test_verify_backup_code_no_codes(self): def test_verify_backup_code_no_codes(self):
"""Test backup code verification with no backup codes.""" """Test backup code verification with no backup codes."""
self.mfa_settings.backup_codes = None self.mfa_settings.backup_codes = None
result = MFAService.verify_backup_code(self.mfa_settings, "ABCD1234") result = MFAService.verify_backup_code(self.mfa_settings, "ABCD1234")
self.assertFalse(result)
@patch('services.mfa_service.MFAService.get_or_create_mfa_settings') assert not result
@patch('services.mfa_service.MFAService.verify_totp')
@patch('services.mfa_service.MFAService.generate_backup_codes') @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
@patch('services.mfa_service.db.session') @patch("services.mfa_service.MFAService.verify_totp")
@patch("services.mfa_service.MFAService.generate_backup_codes")
@patch("services.mfa_service.db.session")
def test_setup_mfa_success(self, mock_session, mock_gen_codes, mock_verify, mock_get_settings): def test_setup_mfa_success(self, mock_session, mock_gen_codes, mock_verify, mock_get_settings):
"""Test successful MFA setup.""" """Test successful MFA setup."""
mock_get_settings.return_value = self.mfa_settings mock_get_settings.return_value = self.mfa_settings
self.mfa_settings.secret = "test_secret" self.mfa_settings.secret = "test_secret"
mock_verify.return_value = True mock_verify.return_value = True
mock_gen_codes.return_value = ["CODE1", "CODE2"] mock_gen_codes.return_value = ["CODE1", "CODE2"]
result = MFAService.setup_mfa(self.account, "123456") result = MFAService.setup_mfa(self.account, "123456")
self.assertTrue(self.mfa_settings.enabled)
self.assertEqual(self.mfa_settings.backup_codes, json.dumps(["CODE1", "CODE2"]))
self.assertIsNotNone(self.mfa_settings.setup_at)
self.assertEqual(result["backup_codes"], ["CODE1", "CODE2"])
@patch('services.mfa_service.MFAService.get_or_create_mfa_settings') assert self.mfa_settings.enabled
assert self.mfa_settings.backup_codes == json.dumps(["CODE1", "CODE2"])
assert self.mfa_settings.setup_at is not None
assert result["backup_codes"] == ["CODE1", "CODE2"]
@patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
def test_setup_mfa_already_enabled(self, mock_get_settings): def test_setup_mfa_already_enabled(self, mock_get_settings):
"""Test MFA setup when already enabled.""" """Test MFA setup when already enabled."""
self.mfa_settings.enabled = True self.mfa_settings.enabled = True
mock_get_settings.return_value = self.mfa_settings mock_get_settings.return_value = self.mfa_settings
with self.assertRaises(ValueError) as context: with pytest.raises(ValueError) as context:
MFAService.setup_mfa(self.account, "123456") MFAService.setup_mfa(self.account, "123456")
self.assertIn("already enabled", str(context.exception))
@patch('services.mfa_service.MFAService.get_or_create_mfa_settings') assert "already enabled" in str(context.value)
@patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
def test_setup_mfa_no_secret(self, mock_get_settings): def test_setup_mfa_no_secret(self, mock_get_settings):
"""Test MFA setup without secret.""" """Test MFA setup without secret."""
mock_get_settings.return_value = self.mfa_settings mock_get_settings.return_value = self.mfa_settings
with self.assertRaises(ValueError) as context: with pytest.raises(ValueError) as context:
MFAService.setup_mfa(self.account, "123456") MFAService.setup_mfa(self.account, "123456")
self.assertIn("secret not generated", str(context.exception))
@patch('services.mfa_service.MFAService.get_or_create_mfa_settings') assert "secret not generated" in str(context.value)
@patch('services.mfa_service.MFAService.verify_totp')
@patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
@patch("services.mfa_service.MFAService.verify_totp")
def test_setup_mfa_invalid_token(self, mock_verify, mock_get_settings): def test_setup_mfa_invalid_token(self, mock_verify, mock_get_settings):
"""Test MFA setup with invalid TOTP token.""" """Test MFA setup with invalid TOTP token."""
mock_get_settings.return_value = self.mfa_settings mock_get_settings.return_value = self.mfa_settings
self.mfa_settings.secret = "test_secret" self.mfa_settings.secret = "test_secret"
mock_verify.return_value = False mock_verify.return_value = False
with self.assertRaises(ValueError) as context: with pytest.raises(ValueError) as context:
MFAService.setup_mfa(self.account, "invalid") MFAService.setup_mfa(self.account, "invalid")
self.assertIn("Invalid TOTP token", str(context.exception))
@patch('services.mfa_service.db.session') assert "Invalid TOTP token" in str(context.value)
@patch("services.mfa_service.db.session")
def test_is_mfa_required_enabled(self, mock_session): def test_is_mfa_required_enabled(self, mock_session):
"""Test MFA requirement check when enabled.""" """Test MFA requirement check when enabled."""
self.mfa_settings.enabled = True self.mfa_settings.enabled = True
self.mfa_settings.secret = "test_secret" self.mfa_settings.secret = "test_secret"
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
result = MFAService.is_mfa_required(self.account) result = MFAService.is_mfa_required(self.account)
self.assertTrue(result)
@patch('services.mfa_service.db.session') assert result
@patch("services.mfa_service.db.session")
def test_is_mfa_required_disabled(self, mock_session): def test_is_mfa_required_disabled(self, mock_session):
"""Test MFA requirement check when disabled.""" """Test MFA requirement check when disabled."""
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
result = MFAService.is_mfa_required(self.account) result = MFAService.is_mfa_required(self.account)
self.assertFalse(result)
@patch('services.mfa_service.db.session') assert not result
@patch("services.mfa_service.db.session")
def test_is_mfa_required_no_settings(self, mock_session): def test_is_mfa_required_no_settings(self, mock_session):
"""Test MFA requirement check with no settings.""" """Test MFA requirement check with no settings."""
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.query.return_value.filter_by.return_value.first.return_value = None
result = MFAService.is_mfa_required(self.account) result = MFAService.is_mfa_required(self.account)
self.assertFalse(result)
@patch('services.mfa_service.db.session') assert not result
@patch('services.mfa_service.MFAService.verify_totp')
@patch('services.mfa_service.MFAService.verify_backup_code') @patch("services.mfa_service.db.session")
@patch("services.mfa_service.MFAService.verify_totp")
@patch("services.mfa_service.MFAService.verify_backup_code")
def test_authenticate_with_mfa_totp_success(self, mock_verify_backup, mock_verify_totp, mock_session): def test_authenticate_with_mfa_totp_success(self, mock_verify_backup, mock_verify_totp, mock_session):
"""Test MFA authentication with valid TOTP.""" """Test MFA authentication with valid TOTP."""
self.mfa_settings.enabled = True self.mfa_settings.enabled = True
self.mfa_settings.secret = "test_secret" self.mfa_settings.secret = "test_secret"
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
mock_verify_totp.return_value = True mock_verify_totp.return_value = True
result = MFAService.authenticate_with_mfa(self.account, "123456") result = MFAService.authenticate_with_mfa(self.account, "123456")
self.assertTrue(result) assert result
mock_verify_totp.assert_called_once_with("test_secret", "123456") mock_verify_totp.assert_called_once_with("test_secret", "123456")
mock_verify_backup.assert_not_called() mock_verify_backup.assert_not_called()
@patch('services.mfa_service.db.session') @patch("services.mfa_service.db.session")
@patch('services.mfa_service.MFAService.verify_totp') @patch("services.mfa_service.MFAService.verify_totp")
@patch('services.mfa_service.MFAService.verify_backup_code') @patch("services.mfa_service.MFAService.verify_backup_code")
def test_authenticate_with_mfa_backup_success(self, mock_verify_backup, mock_verify_totp, mock_session): def test_authenticate_with_mfa_backup_success(self, mock_verify_backup, mock_verify_totp, mock_session):
"""Test MFA authentication with valid backup code.""" """Test MFA authentication with valid backup code."""
self.mfa_settings.enabled = True self.mfa_settings.enabled = True
@ -224,145 +226,133 @@ class TestMFAService(unittest.TestCase):
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
mock_verify_totp.return_value = False mock_verify_totp.return_value = False
mock_verify_backup.return_value = True mock_verify_backup.return_value = True
result = MFAService.authenticate_with_mfa(self.account, "BACKUP123") result = MFAService.authenticate_with_mfa(self.account, "BACKUP123")
self.assertTrue(result) assert result
mock_verify_totp.assert_called_once_with("test_secret", "BACKUP123") mock_verify_totp.assert_called_once_with("test_secret", "BACKUP123")
mock_verify_backup.assert_called_once_with(self.mfa_settings, "BACKUP123") mock_verify_backup.assert_called_once_with(self.mfa_settings, "BACKUP123")
@patch('services.mfa_service.db.session') @patch("services.mfa_service.db.session")
def test_authenticate_with_mfa_disabled(self, mock_session): def test_authenticate_with_mfa_disabled(self, mock_session):
"""Test MFA authentication when disabled.""" """Test MFA authentication when disabled."""
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
result = MFAService.authenticate_with_mfa(self.account, "123456") result = MFAService.authenticate_with_mfa(self.account, "123456")
self.assertTrue(result)
@patch('services.mfa_service.db.session') assert result
@patch("services.mfa_service.db.session")
def test_get_mfa_status_enabled(self, mock_session): def test_get_mfa_status_enabled(self, mock_session):
"""Test getting MFA status when enabled.""" """Test getting MFA status when enabled."""
self.mfa_settings.enabled = True self.mfa_settings.enabled = True
self.mfa_settings.setup_at = datetime(2025, 1, 1, 12, 0, 0) self.mfa_settings.setup_at = datetime(2025, 1, 1, 12, 0, 0)
self.mfa_settings.backup_codes = json.dumps(["CODE1", "CODE2"]) self.mfa_settings.backup_codes = json.dumps(["CODE1", "CODE2"])
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
result = MFAService.get_mfa_status(self.account) result = MFAService.get_mfa_status(self.account)
expected = { expected = {"enabled": True, "setup_at": "2025-01-01T12:00:00", "has_backup_codes": True}
"enabled": True, assert result == expected
"setup_at": "2025-01-01T12:00:00",
"has_backup_codes": True @patch("services.mfa_service.db.session")
}
self.assertEqual(result, expected)
@patch('services.mfa_service.db.session')
def test_get_mfa_status_no_settings(self, mock_session): def test_get_mfa_status_no_settings(self, mock_session):
"""Test getting MFA status with no settings.""" """Test getting MFA status with no settings."""
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.query.return_value.filter_by.return_value.first.return_value = None
result = MFAService.get_mfa_status(self.account) result = MFAService.get_mfa_status(self.account)
expected = { expected = {"enabled": False, "setup_at": None, "has_backup_codes": False}
"enabled": False, assert result == expected
"setup_at": None,
"has_backup_codes": False @patch("qrcode.QRCode")
} @patch("pyotp.TOTP")
self.assertEqual(result, expected)
@patch('qrcode.QRCode')
@patch('pyotp.TOTP')
def test_generate_qr_code(self, mock_totp_class, mock_qr_class): def test_generate_qr_code(self, mock_totp_class, mock_qr_class):
"""Test QR code generation.""" """Test QR code generation."""
# Mock TOTP # Mock TOTP
mock_totp = Mock() mock_totp = Mock()
mock_totp.provisioning_uri.return_value = "otpauth://totp/test" mock_totp.provisioning_uri.return_value = "otpauth://totp/test"
mock_totp_class.return_value = mock_totp mock_totp_class.return_value = mock_totp
# Mock QR code # Mock QR code
mock_qr = Mock() mock_qr = Mock()
mock_img = Mock() mock_img = Mock()
mock_qr.make_image.return_value = mock_img mock_qr.make_image.return_value = mock_img
mock_qr_class.return_value = mock_qr mock_qr_class.return_value = mock_qr
# Mock image buffer # Mock image buffer
with patch('io.BytesIO') as mock_buffer, \ with patch("io.BytesIO") as mock_buffer, patch("base64.b64encode") as mock_b64:
patch('base64.b64encode') as mock_b64:
mock_b64.return_value.decode.return_value = "base64data" mock_b64.return_value.decode.return_value = "base64data"
result = MFAService.generate_qr_code(self.account, "test_secret") result = MFAService.generate_qr_code(self.account, "test_secret")
self.assertEqual(result, "data:image/png;base64,base64data") assert result == "data:image/png;base64,base64data"
mock_totp.provisioning_uri.assert_called_once_with( mock_totp.provisioning_uri.assert_called_once_with(name=self.account.email, issuer_name="Dify")
name=self.account.email,
issuer_name="Dify" @patch("libs.password.compare_password")
) @patch("services.mfa_service.db.session")
@patch('libs.password.compare_password')
@patch('services.mfa_service.db.session')
def test_disable_mfa_success(self, mock_session, mock_compare_password): def test_disable_mfa_success(self, mock_session, mock_compare_password):
"""Test successful MFA disable.""" """Test successful MFA disable."""
mock_compare_password.return_value = True mock_compare_password.return_value = True
mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings mock_session.query.return_value.filter_by.return_value.first.return_value = self.mfa_settings
result = MFAService.disable_mfa(self.account, "correct_password") result = MFAService.disable_mfa(self.account, "correct_password")
self.assertTrue(result) assert result
self.assertFalse(self.mfa_settings.enabled) assert not self.mfa_settings.enabled
self.assertIsNone(self.mfa_settings.secret) assert self.mfa_settings.secret is None
self.assertIsNone(self.mfa_settings.backup_codes) assert self.mfa_settings.backup_codes is None
self.assertIsNone(self.mfa_settings.setup_at) assert self.mfa_settings.setup_at is None
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
@patch('libs.password.compare_password') @patch("libs.password.compare_password")
def test_disable_mfa_wrong_password(self, mock_compare_password): def test_disable_mfa_wrong_password(self, mock_compare_password):
"""Test MFA disable with wrong password.""" """Test MFA disable with wrong password."""
mock_compare_password.return_value = False mock_compare_password.return_value = False
result = MFAService.disable_mfa(self.account, "wrong_password") result = MFAService.disable_mfa(self.account, "wrong_password")
self.assertFalse(result)
@patch('libs.password.compare_password') assert not result
@patch('services.mfa_service.db.session')
@patch("libs.password.compare_password")
@patch("services.mfa_service.db.session")
def test_disable_mfa_no_settings(self, mock_session, mock_compare_password): def test_disable_mfa_no_settings(self, mock_session, mock_compare_password):
"""Test MFA disable when no settings exist.""" """Test MFA disable when no settings exist."""
mock_compare_password.return_value = True mock_compare_password.return_value = True
mock_session.query.return_value.filter_by.return_value.first.return_value = None mock_session.query.return_value.filter_by.return_value.first.return_value = None
result = MFAService.disable_mfa(self.account, "correct_password") result = MFAService.disable_mfa(self.account, "correct_password")
self.assertTrue(result) # Already disabled
@patch('services.mfa_service.MFAService.get_or_create_mfa_settings') assert result # Already disabled
@patch('services.mfa_service.MFAService.generate_secret')
@patch('services.mfa_service.MFAService.generate_qr_code') @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
@patch('services.mfa_service.db.session') @patch("services.mfa_service.MFAService.generate_secret")
@patch("services.mfa_service.MFAService.generate_qr_code")
@patch("services.mfa_service.db.session")
def test_generate_mfa_setup_data_success(self, mock_session, mock_gen_qr, mock_gen_secret, mock_get_settings): def test_generate_mfa_setup_data_success(self, mock_session, mock_gen_qr, mock_gen_secret, mock_get_settings):
"""Test successful MFA setup data generation.""" """Test successful MFA setup data generation."""
mock_get_settings.return_value = self.mfa_settings mock_get_settings.return_value = self.mfa_settings
mock_gen_secret.return_value = "NEWSECRET123" mock_gen_secret.return_value = "NEWSECRET123"
mock_gen_qr.return_value = "data:image/png;base64,qrdata" mock_gen_qr.return_value = "data:image/png;base64,qrdata"
result = MFAService.generate_mfa_setup_data(self.account) result = MFAService.generate_mfa_setup_data(self.account)
self.assertEqual(result["secret"], "NEWSECRET123") assert result["secret"] == "NEWSECRET123"
self.assertEqual(result["qr_code"], "data:image/png;base64,qrdata") assert result["qr_code"] == "data:image/png;base64,qrdata"
self.assertEqual(self.mfa_settings.secret, "NEWSECRET123") assert self.mfa_settings.secret == "NEWSECRET123"
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
@patch('services.mfa_service.MFAService.get_or_create_mfa_settings') @patch("services.mfa_service.MFAService.get_or_create_mfa_settings")
def test_generate_mfa_setup_data_already_enabled(self, mock_get_settings): def test_generate_mfa_setup_data_already_enabled(self, mock_get_settings):
"""Test MFA setup data generation when already enabled.""" """Test MFA setup data generation when already enabled."""
self.mfa_settings.enabled = True self.mfa_settings.enabled = True
mock_get_settings.return_value = self.mfa_settings mock_get_settings.return_value = self.mfa_settings
with self.assertRaises(ValueError) as context: with pytest.raises(ValueError) as context:
MFAService.generate_mfa_setup_data(self.account) MFAService.generate_mfa_setup_data(self.account)
self.assertIn("already enabled", str(context.exception)) assert "already enabled" in str(context.value)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

@ -354,7 +354,7 @@ importers:
devDependencies: devDependencies:
'@antfu/eslint-config': '@antfu/eslint-config':
specifier: ^4.1.1 specifier: ^4.1.1
version: 4.12.0(@eslint-react/eslint-plugin@1.45.0(eslint@9.24.0(jiti@1.21.7))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3))(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(@vue/compiler-sfc@3.5.13)(eslint-plugin-react-hooks@5.2.0(eslint@9.24.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.19(eslint@9.24.0(jiti@1.21.7)))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.4.4)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1)) version: 4.12.0(@eslint-react/eslint-plugin@1.45.0(eslint@9.24.0(jiti@1.21.7))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3))(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(@vue/compiler-sfc@3.5.13)(eslint-plugin-react-hooks@5.2.0(eslint@9.24.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.19(eslint@9.24.0(jiti@1.21.7)))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.6.3)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1))
'@chromatic-com/storybook': '@chromatic-com/storybook':
specifier: ^3.1.0 specifier: ^3.1.0
version: 3.2.6(react@19.1.0)(storybook@8.5.0) version: 3.2.6(react@19.1.0)(storybook@8.5.0)
@ -371,8 +371,8 @@ importers:
specifier: ^9.0.3 specifier: ^9.0.3
version: 9.6.0 version: 9.6.0
'@happy-dom/jest-environment': '@happy-dom/jest-environment':
specifier: ^17.4.4 specifier: ^17.6.3
version: 17.4.4 version: 17.6.3
'@next/eslint-plugin-next': '@next/eslint-plugin-next':
specifier: ~15.3.5 specifier: ~15.3.5
version: 15.3.5 version: 15.3.5
@ -410,10 +410,10 @@ importers:
specifier: ^10.4.0 specifier: ^10.4.0
version: 10.4.0 version: 10.4.0
'@testing-library/jest-dom': '@testing-library/jest-dom':
specifier: ^6.6.2 specifier: ^6.6.3
version: 6.6.3 version: 6.6.3
'@testing-library/react': '@testing-library/react':
specifier: ^16.0.1 specifier: ^16.3.0
version: 16.3.0(@testing-library/dom@10.4.0)(@types/react-dom@19.1.6(@types/react@19.1.8))(@types/react@19.1.8)(react-dom@19.1.0(react@19.1.0))(react@19.1.0) version: 16.3.0(@testing-library/dom@10.4.0)(@types/react-dom@19.1.6(@types/react@19.1.8))(@types/react@19.1.8)(react-dom@19.1.0(react@19.1.0))(react@19.1.0)
'@types/crypto-js': '@types/crypto-js':
specifier: ^4.2.2 specifier: ^4.2.2
@ -1740,9 +1740,9 @@ packages:
'@formatjs/intl-localematcher@0.5.10': '@formatjs/intl-localematcher@0.5.10':
resolution: {integrity: sha512-af3qATX+m4Rnd9+wHcjJ4w2ijq+rAVP3CCinJQvFv1kgSu1W6jypUmvleJxcewdxmutM8dmIRZFxO/IQBZmP2Q==} resolution: {integrity: sha512-af3qATX+m4Rnd9+wHcjJ4w2ijq+rAVP3CCinJQvFv1kgSu1W6jypUmvleJxcewdxmutM8dmIRZFxO/IQBZmP2Q==}
'@happy-dom/jest-environment@17.4.4': '@happy-dom/jest-environment@17.6.3':
resolution: {integrity: sha512-5imA+SpP7ZcIwE1u2swWZq6UJhyZIWNtlE/gnqhVz+y91G6hgF+t9hVSsWH29Tfib+wg/zC9ryJPDDyAuqXfEg==} resolution: {integrity: sha512-HXuHKvpHLo9/GQ/yKMmKFyS1AYL2t9pL67+GfpYZfOAb29qD80EMozi50zRZk82KmNRBcA2A0/ErjpOwUxJrNg==}
engines: {node: '>=18.0.0'} engines: {node: '>=20.0.0'}
'@headlessui/react@2.2.1': '@headlessui/react@2.2.1':
resolution: {integrity: sha512-daiUqVLae8CKVjEVT19P/izW0aGK0GNhMSAeMlrDebKmoVZHcRRwbxzgtnEadUVDXyBsWo9/UH4KHeniO+0tMg==} resolution: {integrity: sha512-daiUqVLae8CKVjEVT19P/izW0aGK0GNhMSAeMlrDebKmoVZHcRRwbxzgtnEadUVDXyBsWo9/UH4KHeniO+0tMg==}
@ -5554,9 +5554,9 @@ packages:
hachure-fill@0.5.2: hachure-fill@0.5.2:
resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==} resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==}
happy-dom@17.4.4: happy-dom@17.6.3:
resolution: {integrity: sha512-/Pb0ctk3HTZ5xEL3BZ0hK1AqDSAUuRQitOmROPHhfUYEWpmTImwfD8vFDGADmMAX0JYgbcgxWoLFKtsWhcpuVA==} resolution: {integrity: sha512-UVIHeVhxmxedbWPCfgS55Jg2rDfwf2BCKeylcPSqazLz5w3Kri7Q4xdBJubsr/+VUzFLh0VjIvh13RaDA2/Xug==}
engines: {node: '>=18.0.0'} engines: {node: '>=20.0.0'}
has-bigints@1.1.0: has-bigints@1.1.0:
resolution: {integrity: sha512-R3pbpkcIqv2Pm3dUwgjclDRVmWpTJW2DcMzcIhEXEx1oh/CEMObMm3KLmRJOdvhM7o4uQBnwr8pzRK2sJWIqfg==} resolution: {integrity: sha512-R3pbpkcIqv2Pm3dUwgjclDRVmWpTJW2DcMzcIhEXEx1oh/CEMObMm3KLmRJOdvhM7o4uQBnwr8pzRK2sJWIqfg==}
@ -8905,7 +8905,7 @@ snapshots:
'@jridgewell/gen-mapping': 0.3.8 '@jridgewell/gen-mapping': 0.3.8
'@jridgewell/trace-mapping': 0.3.25 '@jridgewell/trace-mapping': 0.3.25
'@antfu/eslint-config@4.12.0(@eslint-react/eslint-plugin@1.45.0(eslint@9.24.0(jiti@1.21.7))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3))(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(@vue/compiler-sfc@3.5.13)(eslint-plugin-react-hooks@5.2.0(eslint@9.24.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.19(eslint@9.24.0(jiti@1.21.7)))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.4.4)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1))': '@antfu/eslint-config@4.12.0(@eslint-react/eslint-plugin@1.45.0(eslint@9.24.0(jiti@1.21.7))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3))(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(@vue/compiler-sfc@3.5.13)(eslint-plugin-react-hooks@5.2.0(eslint@9.24.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.4.19(eslint@9.24.0(jiti@1.21.7)))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.6.3)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1))':
dependencies: dependencies:
'@antfu/install-pkg': 1.0.0 '@antfu/install-pkg': 1.0.0
'@clack/prompts': 0.10.1 '@clack/prompts': 0.10.1
@ -8914,7 +8914,7 @@ snapshots:
'@stylistic/eslint-plugin': 4.2.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3) '@stylistic/eslint-plugin': 4.2.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)
'@typescript-eslint/eslint-plugin': 8.29.1(@typescript-eslint/parser@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3) '@typescript-eslint/eslint-plugin': 8.29.1(@typescript-eslint/parser@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)
'@typescript-eslint/parser': 8.29.1(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3) '@typescript-eslint/parser': 8.29.1(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)
'@vitest/eslint-plugin': 1.1.42(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.4.4)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1)) '@vitest/eslint-plugin': 1.1.42(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.6.3)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1))
ansis: 3.17.0 ansis: 3.17.0
cac: 6.7.14 cac: 6.7.14
eslint: 9.24.0(jiti@1.21.7) eslint: 9.24.0(jiti@1.21.7)
@ -10225,12 +10225,12 @@ snapshots:
dependencies: dependencies:
tslib: 2.8.1 tslib: 2.8.1
'@happy-dom/jest-environment@17.4.4': '@happy-dom/jest-environment@17.6.3':
dependencies: dependencies:
'@jest/environment': 29.7.0 '@jest/environment': 29.7.0
'@jest/fake-timers': 29.7.0 '@jest/fake-timers': 29.7.0
'@jest/types': 29.6.3 '@jest/types': 29.6.3
happy-dom: 17.4.4 happy-dom: 17.6.3
jest-mock: 29.7.0 jest-mock: 29.7.0
jest-util: 29.7.0 jest-util: 29.7.0
@ -12341,11 +12341,11 @@ snapshots:
'@unrs/resolver-binding-win32-x64-msvc@1.4.1': '@unrs/resolver-binding-win32-x64-msvc@1.4.1':
optional: true optional: true
'@vitest/eslint-plugin@1.1.42(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.4.4)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1))': '@vitest/eslint-plugin@1.1.42(@typescript-eslint/utils@8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3))(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)(vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.6.3)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1))':
dependencies: dependencies:
'@typescript-eslint/utils': 8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3) '@typescript-eslint/utils': 8.36.0(eslint@9.24.0(jiti@1.21.7))(typescript@5.8.3)
eslint: 9.24.0(jiti@1.21.7) eslint: 9.24.0(jiti@1.21.7)
vitest: 3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.4.4)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1) vitest: 3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.6.3)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1)
optionalDependencies: optionalDependencies:
typescript: 5.8.3 typescript: 5.8.3
@ -13183,7 +13183,7 @@ snapshots:
code-inspector-core@0.18.3: code-inspector-core@0.18.3:
dependencies: dependencies:
'@vue/compiler-dom': 3.5.13 '@vue/compiler-dom': 3.5.13
chalk: 4.1.1 chalk: 4.1.2
dotenv: 16.5.0 dotenv: 16.5.0
launch-ide: 1.0.1 launch-ide: 1.0.1
portfinder: 1.0.35 portfinder: 1.0.35
@ -14963,7 +14963,7 @@ snapshots:
hachure-fill@0.5.2: {} hachure-fill@0.5.2: {}
happy-dom@17.4.4: happy-dom@17.6.3:
dependencies: dependencies:
webidl-conversions: 7.0.0 webidl-conversions: 7.0.0
whatwg-mimetype: 3.0.0 whatwg-mimetype: 3.0.0
@ -15960,7 +15960,7 @@ snapshots:
launch-ide@1.0.1: launch-ide@1.0.1:
dependencies: dependencies:
chalk: 4.1.1 chalk: 4.1.2
dotenv: 16.5.0 dotenv: 16.5.0
layout-base@1.0.2: {} layout-base@1.0.2: {}
@ -18881,7 +18881,7 @@ snapshots:
terser: 5.39.0 terser: 5.39.0
yaml: 2.7.1 yaml: 2.7.1
vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.4.4)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1): vitest@3.1.1(@types/debug@4.1.12)(@types/node@18.15.0)(happy-dom@17.6.3)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1):
dependencies: dependencies:
'@vitest/expect': 3.1.1 '@vitest/expect': 3.1.1
'@vitest/mocker': 3.1.1(vite@6.2.7(@types/node@18.15.0)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1)) '@vitest/mocker': 3.1.1(vite@6.2.7(@types/node@18.15.0)(jiti@1.21.7)(sass@1.86.3)(terser@5.39.0)(yaml@2.7.1))
@ -18906,7 +18906,7 @@ snapshots:
optionalDependencies: optionalDependencies:
'@types/debug': 4.1.12 '@types/debug': 4.1.12
'@types/node': 18.15.0 '@types/node': 18.15.0
happy-dom: 17.4.4 happy-dom: 17.6.3
transitivePeerDependencies: transitivePeerDependencies:
- jiti - jiti
- less - less

Loading…
Cancel
Save