diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 31782b9554..8be8f7e893 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -6,13 +6,13 @@ from flask_restful import Resource, reqparse # type: ignore from constants.languages import languages from controllers.console import api -from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError -from controllers.console.error import ( - AccountInFreezeError, - AccountNotFound, - EmailSendIpLimitError, -) -from controllers.console.wraps import setup_required +from controllers.console.auth.error import (EmailCodeError, InvalidEmailError, + InvalidTokenError, + PasswordMismatchError) +from controllers.console.error import (AccountInFreezeError, AccountNotFound, + EmailSendIpLimitError) +from controllers.console.wraps import (email_password_login_enabled, + setup_required) from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import email, extract_remote_ip @@ -20,7 +20,7 @@ from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError +from services.errors.workspace import WorkSpaceNotAllowedCreateError from services.feature_service import FeatureService @@ -131,8 +131,6 @@ class ForgotPasswordResetApi(Resource): pass except AccountRegisterError as are: raise AccountInFreezeError() - except WorkspacesLimitExceededError: - pass return {"result": "success"} diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 4dc04a0455..58de237fbf 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -23,7 +23,7 @@ from controllers.console.error import ( NotAllowedCreateWorkspace, WorkspacesLimitExceeded, ) -from controllers.console.wraps import setup_required +from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.password import valid_password @@ -39,6 +39,7 @@ class LoginApi(Resource): """Resource for user login.""" @setup_required + @email_password_login_enabled def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() @@ -120,6 +121,7 @@ class LogoutApi(Resource): class ResetPasswordSendEmailApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index b7ba81fba2..4fbf4092fd 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,3 +1,4 @@ +import logging from functools import wraps from flask_login import current_user # type: ignore @@ -8,6 +9,9 @@ from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import login_required from models import InstalledApp +from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService def installed_app_required(view=None): @@ -48,6 +52,32 @@ def installed_app_required(view=None): return decorator +def user_allowed_to_access_app(view=None): + def decorator(view): + @wraps(view) + def decorated(installed_app: InstalledApp, *args, **kwargs): + feature = FeatureService.get_system_features() + if feature.webapp_auth.enabled: + app_id = installed_app.app_id + app_code = AppService.get_app_code_by_id(app_id) + logging.info(f"app_id: {app_id}, app_code: {app_code}") + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( + user_id=str(current_user.id), + app_code=app_code, + ) + logging.info(f"res: {res}") + if not res: + raise ValueError("User not allowed to access this app") + + return view(installed_app, *args, **kwargs) + + return decorated + if view: + return decorator(view) + return decorator + + class InstalledAppResource(Resource): # must be reversed if there are multiple decorators - method_decorators = [installed_app_required, account_initialization_required, login_required] + + method_decorators = [user_allowed_to_access_app, installed_app_required, account_initialization_required, login_required] diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 510ea242a9..11b7140b1b 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -11,7 +11,8 @@ from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService -from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout +from .error import (NotInitValidateError, NotSetupError, + UnauthorizedAndForceLogout) def account_initialization_required(view): @@ -165,3 +166,16 @@ def enterprise_license_required(view): return view(*args, **kwargs) return decorated + + +def email_password_login_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.enable_email_password_login: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 20e071c834..823b80cfce 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,12 +1,17 @@ -from flask_restful import marshal_with # type: ignore +import logging + +from flask import request +from flask_restful import Resource, marshal_with, reqparse # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource +from libs.passport import PassportService from models.model import App, AppMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService class AppParameterApi(WebApiResource): @@ -42,5 +47,54 @@ class AppMeta(WebApiResource): return AppService().get_app_meta(app_model) +class AppAccessMode(Resource): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=True, location="args") + args = parser.parse_args() + + app_id = args["appId"] + res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) + + return {"accessMode": res.access_mode} + + +class AppWebAuthPermission(Resource): + def get(self): + user_id = "visitor" + try: + auth_header = request.headers.get("Authorization") + if auth_header is None: + raise + if " " not in auth_header: + raise + + auth_scheme, tk = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise + + decoded = PassportService().verify(tk) + user_id = decoded.get("user_id", "visitor") + except Exception as e: + pass + + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=True, location="args") + args = parser.parse_args() + + app_id = args["appId"] + logging.info(f"App ID: {app_id}, User ID: {user_id}") + + app_code = AppService.get_app_code_by_id(app_id) + logging.info(f"App code: {app_code}") + + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + return {"result": res} + + api.add_resource(AppParameterApi, "/parameters") api.add_resource(AppMeta, "/meta") +# webapp auth apis +api.add_resource(AppAccessMode, "/webapp/access-mode") +api.add_resource(AppWebAuthPermission, "/webapp/permission") diff --git a/api/services/account_service.py b/api/services/account_service.py index 1e5121b072..c6b8d0e9b7 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -405,10 +405,8 @@ class AccountService: raise PasswordResetRateLimitExceededError() - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) - token = TokenManager.generate_token( - account=account, email=email, token_type="reset_password", additional_data={"code": code} - ) + code, token = cls.generate_reset_password_token(account_email, account) + send_reset_password_mail_task.delay( language=language, to=account_email, @@ -417,6 +415,22 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def generate_reset_password_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") diff --git a/api/services/app_service.py b/api/services/app_service.py index 9359bb2844..e6a1ae32a9 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -19,7 +19,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -384,3 +384,15 @@ class AppService: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta + + @staticmethod + def get_app_code_by_id(app_id: str) -> str: + """ + Get app code by app id + :param app_id: app id + :return: app code + """ + site = db.session.query(Site).filter(Site.app_id == app_id).first() + if not site: + raise ValueError(f"App with id {app_id} not found") + return str(site.code) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 611f69e715..a1b6216b25 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -45,12 +45,12 @@ class EnterpriseService: if not data: raise ValueError("No data found.") - if not isinstance(data['accessModes'], dict): + if not isinstance(data["accessModes"], dict): logging.info(f"Batch get app access mode by id returns data: {data}") raise ValueError("Invalid data format.") ret = {} - for key, value in data['accessModes'].items(): + for key, value in data["accessModes"].items(): curr = WebAppSettings() curr.access_mode = value ret[key] = curr diff --git a/api/services/feature_service.py b/api/services/feature_service.py index a366a9d074..199b63e679 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -187,6 +187,9 @@ class FeatureService: if "SSOEnforcedForSignin" in enterprise_info: features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"] + if "SSOEnforcedForSigninProtocol" in enterprise_info: + features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"] + if "EnableEmailCodeLogin" in enterprise_info: features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"] diff --git a/web/app/reset-password/check-code/page.tsx b/web/app/reset-password/check-code/page.tsx index ca53b68750..b8cb87629b 100644 --- a/web/app/reset-password/check-code/page.tsx +++ b/web/app/reset-password/check-code/page.tsx @@ -39,7 +39,11 @@ export default function CheckCode() { } setIsLoading(true) const ret = await verifyResetPasswordCode({ email, code, token }) - ret.is_valid && router.push(`/reset-password/set-password?${searchParams.toString()}`) + if (ret.is_valid) { + const params = new URLSearchParams(searchParams) + params.set('token', encodeURIComponent(ret.token)) + router.push(`/reset-password/set-password?${params.toString()}`) + } } catch (error) { console.error(error) } finally { diff --git a/web/service/common.ts b/web/service/common.ts index 7ccc6eb5c3..06562ece0c 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -338,7 +338,7 @@ export const sendResetPasswordCode = (email: string, language = 'en-US') => post('/forgot-password', { body: { email, language } }) export const verifyResetPasswordCode = (body: { email: string; code: string; token: string }) => - post('/forgot-password/validity', { body }) + post('/forgot-password/validity', { body }) export const sendDeleteAccountCode = () => get('/account/delete/verify')