From 46d43e6758267e5b4337b30ead15bc73ac19f900 Mon Sep 17 00:00:00 2001 From: GareArc Date: Mon, 7 Apr 2025 17:03:26 -0400 Subject: [PATCH 01/14] feat: add web app auth --- api/controllers/web/error.py | 12 +- api/controllers/web/login.py | 118 ++++++++++++++++++ api/controllers/web/passport.py | 10 +- api/controllers/web/wraps.py | 54 +++++--- api/services/enterprise/enterprise_service.py | 28 ++++- api/services/feature_service.py | 14 +++ api/services/webapp_auth_service.py | 103 +++++++++++++++ 7 files changed, 309 insertions(+), 30 deletions(-) create mode 100644 api/controllers/web/login.py create mode 100644 api/services/webapp_auth_service.py diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 9fe5d08d54..4909694d26 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class WebSSOAuthRequiredError(BaseHTTPException): - error_code = "web_sso_auth_required" - description = "Web SSO authentication required." +class WebAppAuthRequiredError(BaseHTTPException): + error_code = "web_auth_required" + description = "Web app authentication required." + code = 401 + + +class WebAppAuthFailedError(BaseHTTPException): + error_code = "web_app_auth_failed" + description = "You do not have permission to access this web app." code = 401 diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py new file mode 100644 index 0000000000..235fcaf8cc --- /dev/null +++ b/api/controllers/web/login.py @@ -0,0 +1,118 @@ +from typing import cast + +import flask_login +from flask import request +from flask_restful import Resource, reqparse +from jwt import InvalidTokenError # type: ignore +from web import api + +import services +from controllers.console.auth.error import (EmailCodeError, + EmailOrPasswordMismatchError, + InvalidEmailError) +from controllers.console.error import AccountBannedError, AccountNotFound +from controllers.console.wraps import setup_required +from libs.helper import email +from libs.password import valid_password +from models.account import Account +from services.account_service import AccountService +from services.webapp_auth_service import Unauthorized, WebAppAuthService + + +class LoginApi(Resource): + """Resource for web app email/password login.""" + + def post(self): + """Authenticate user and login.""" + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + args = parser.parse_args() + + app_code = request.headers.get("X-App-Code") + if app_code is None: + raise Unauthorized("X-App-Code header is missing.") + + try: + account = WebAppAuthService.authenticate(args["email"], args["password"]) + except services.errors.account.AccountLoginError: + raise AccountBannedError() + except services.errors.account.AccountPasswordError: + raise EmailOrPasswordMismatchError() + except services.errors.account.AccountNotFoundError: + raise AccountNotFound() + + token = WebAppAuthService.login(account=account, app_code=app_code) + return {"result": "success", "token": token} + + +class LogoutApi(Resource): + @setup_required + def get(self): + account = cast(Account, flask_login.current_user) + if isinstance(account, flask_login.AnonymousUserMixin): + return {"result": "success"} + flask_login.logout_user() + return {"result": "success"} + + +class EmailCodeLoginSendEmailApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + account = WebAppAuthService.get_user_through_email(args["email"]) + if account is None: + raise AccountNotFound() + else: + token = WebAppAuthService.send_email_code_login_email(account=account, language=language) + + return {"result": "success", "data": token} + + +class EmailCodeLoginApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, location="json") + args = parser.parse_args() + + user_email = args["email"] + app_code = request.headers.get("X-App-Code") + if app_code is None: + raise Unauthorized("X-App-Code header is missing.") + + token_data = WebAppAuthService.get_email_code_login_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if token_data["email"] != args["email"]: + raise InvalidEmailError() + + if token_data["code"] != args["code"]: + raise EmailCodeError() + + WebAppAuthService.revoke_email_code_login_token(args["token"]) + account = WebAppAuthService.get_user_through_email(user_email) + if not account: + raise AccountNotFound() + + token = WebAppAuthService.login(account=account, app_code=app_code) + AccountService.reset_login_error_rate_limit(args["email"]) + return {"result": "success", "token": token} + + +api.add_resource(LoginApi, "/login") +api.add_resource(LogoutApi, "/logout") +api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 4625c1f43d..3c1f0a415f 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -5,7 +5,7 @@ from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -22,10 +22,10 @@ class PassportResource(Resource): if app_code is None: raise Unauthorized("X-App-Code header is missing.") - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + if system_features.webapp_auth.enabled: + app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) + if not app_settings or not app_settings.access_mode == "public": + raise WebAppAuthRequiredError() # get site from db and check if it is normal site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1b4d263bee..482d7859fa 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,7 +4,8 @@ from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import (WebAppAuthFailedError, + WebAppAuthRequiredError) from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -57,35 +58,48 @@ def decode_jwt_token(): if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features, app_code) + # for enterprise webapp auth + app_web_auth_enabled = False + if system_features.webapp_auth.enabled: + app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + + _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) + _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) return app_model, end_user except Unauthorized as e: - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + if system_features.webapp_auth.enabled: + app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + if app_web_auth_enabled: + raise WebAppAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features, app_code): - app_web_sso_enabled = False - - # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - source = decoded.get("token_source") - if not source or source != "sso": - raise WebSSOAuthRequiredError() +def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): + # Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login + if system_webapp_auth_enabled and app_web_auth_enabled: + source = decoded.get("token_source") + if not source or source != "webapp": + raise WebAppAuthRequiredError() - # Check if SSO is not enforced for web, and if the token source is SSO, + # Check if authentication is not enforced for web, and if the token source is webapp, # raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web or not app_web_sso_enabled: + if not system_webapp_auth_enabled or not app_web_auth_enabled: source = decoded.get("token_source") - if source and source == "sso": - raise Unauthorized("sso token expired.") + if source and source == "webapp": + raise Unauthorized("webapp token expired.") + + +def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): + if system_webapp_auth_enabled and app_web_auth_enabled: + # Check if the user is allowed to access the web app + user_id = decoded.get("user_id") + if not user_id: + raise WebAppAuthRequiredError() + + if not EnterpriseService.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + raise WebAppAuthFailedError() class WebApiResource(Resource): diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index abc01ddf8f..2ff3b3348a 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,11 +1,35 @@ +from pydantic import BaseModel, Field + from services.enterprise.base import EnterpriseRequest +class WebAppSettings(BaseModel): + access_mode: str = Field( + description="Access mode for the web app. Can be 'public' or 'private'", + default="private", + alias="access_mode", + ) + + class EnterpriseService: @classmethod def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") @classmethod - def get_app_web_sso_enabled(cls, app_code): - return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") + def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool: + if not app_id and not app_code: + raise ValueError("Either app_id or app_code must be provided.") + + return EnterpriseRequest.send_request( + "GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}" + ) + + @classmethod + def get_web_app_settings(cls, app_code: str = None, app_id: str = None): + if not app_code and not app_id: + raise ValueError("Either app_code or app_id must be provided.") + data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}") + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 959e0221b5..f37f4e0d92 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -44,6 +44,13 @@ class BrandingModel(BaseModel): favicon: str = "" +class WebAppAuthModel(BaseModel): + enabled: bool = False + allow_sso: bool = False + allow_email_code_login: bool = False + allow_email_password_login: bool = False + + class FeatureModel(BaseModel): billing: BillingModel = BillingModel() members: LimitationModel = LimitationModel(size=0, limit=1) @@ -75,6 +82,7 @@ class SystemFeatureModel(BaseModel): is_email_setup: bool = False license: LicenseModel = LicenseModel() branding: BrandingModel = BrandingModel() + webapp_auth: WebAppAuthModel = WebAppAuthModel() class FeatureService: @@ -101,6 +109,7 @@ class FeatureService: if dify_config.ENTERPRISE_ENABLED: system_features.enable_web_sso_switch_component = True system_features.branding.enabled = True + system_features.webapp_auth.enabled = True cls._fulfill_params_from_enterprise(system_features) return system_features @@ -194,6 +203,11 @@ class FeatureService: features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "") features.branding.favicon = enterprise_info["Branding"].get("favicon", "") + if "WebAppAuth" in enterprise_info: + features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False) + features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False) + features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False) + if "License" in enterprise_info: license_info = enterprise_info["License"] diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py new file mode 100644 index 0000000000..65501bbffa --- /dev/null +++ b/api/services/webapp_auth_service.py @@ -0,0 +1,103 @@ +import random +from datetime import UTC, datetime, timedelta +from typing import Any, Optional, cast + +from werkzeug.exceptions import NotFound, Unauthorized + +from configs import dify_config +from extensions.ext_database import db +from libs.helper import TokenManager +from libs.passport import PassportService +from libs.password import compare_password +from models.account import Account, AccountStatus +from models.model import Site +from services.errors.account import (AccountLoginError, AccountNotFoundError, + AccountPasswordError) +from tasks.mail_email_code_login import send_email_code_login_mail_task + + +class WebAppAuthService: + """Service for web app authentication.""" + + @staticmethod + def authenticate(email: str, password: str) -> Account: + """authenticate account with email and password""" + + account = Account.query.filter_by(email=email).first() + if not account: + raise AccountNotFoundError() + + if account.status == AccountStatus.BANNED.value: + raise AccountLoginError("Account is banned.") + + if account.password is None or not compare_password(password, account.password, account.password_salt): + raise AccountPasswordError("Invalid email or password.") + + return cast(Account, account) + + @staticmethod + def login(account: Account, app_code: str) -> str: + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise NotFound("Site not found.") + + access_token = WebAppAuthService._get_account_jwt_token(account=account, site=site) + + return access_token + + @classmethod + def get_user_through_email(cls, email: str): + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") + + return account + + @classmethod + def send_email_code_login_email( + cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + ): + email = account.email if account else email + if email is None: + raise ValueError("Email must be provided.") + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code} + ) + send_email_code_login_mail_task.delay( + language=language, + to=account.email if account else email, + code=code, + ) + + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "webapp_email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "webapp_email_code_login") + + @staticmethod + def _get_account_jwt_token(account: Account, site: Site) -> str: + exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24) + exp = int(exp_dt.timestamp()) + + payload = { + "iss": site.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "user_id": account.id, + "token_source": "webapp", + "exp": exp, + } + + token: str = PassportService().issue(payload) + return token From 5e50570739060640f89e9cc0817320bad64b0c40 Mon Sep 17 00:00:00 2001 From: GareArc Date: Mon, 7 Apr 2025 18:41:02 -0400 Subject: [PATCH 02/14] fix: update webapp jwt claim and add user accessibility support --- api/controllers/web/login.py | 19 +++++++++--- api/services/webapp_auth_service.py | 48 +++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 235fcaf8cc..955c781989 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -5,6 +5,7 @@ from flask import request from flask_restful import Resource, reqparse from jwt import InvalidTokenError # type: ignore from web import api +from werkzeug.exceptions import BadRequest import services from controllers.console.auth.error import (EmailCodeError, @@ -16,7 +17,7 @@ from libs.helper import email from libs.password import valid_password from models.account import Account from services.account_service import AccountService -from services.webapp_auth_service import Unauthorized, WebAppAuthService +from services.webapp_auth_service import WebAppAuthService class LoginApi(Resource): @@ -31,7 +32,7 @@ class LoginApi(Resource): app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized("X-App-Code header is missing.") + raise BadRequest("X-App-Code header is missing.") try: account = WebAppAuthService.authenticate(args["email"], args["password"]) @@ -42,7 +43,11 @@ class LoginApi(Resource): except services.errors.account.AccountNotFoundError: raise AccountNotFound() - token = WebAppAuthService.login(account=account, app_code=app_code) + WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) + + end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code) + + token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) return {"result": "success", "token": token} @@ -90,7 +95,7 @@ class EmailCodeLoginApi(Resource): user_email = args["email"] app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized("X-App-Code header is missing.") + raise BadRequest("X-App-Code header is missing.") token_data = WebAppAuthService.get_email_code_login_data(args["token"]) if token_data is None: @@ -107,7 +112,11 @@ class EmailCodeLoginApi(Resource): if not account: raise AccountNotFound() - token = WebAppAuthService.login(account=account, app_code=app_code) + WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) + + end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code) + + token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "token": token} diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 65501bbffa..24d1177d87 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -5,14 +5,18 @@ from typing import Any, Optional, cast from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config +from controllers.web.error import (WebAppAuthFailedError, + WebAppAuthRequiredError) from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService from libs.password import compare_password from models.account import Account, AccountStatus -from models.model import Site +from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.errors.account import (AccountLoginError, AccountNotFoundError, AccountPasswordError) +from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task @@ -35,13 +39,13 @@ class WebAppAuthService: return cast(Account, account) - @staticmethod - def login(account: Account, app_code: str) -> str: + @classmethod + def login(cls, account: Account, app_code: str, end_user_id: str) -> str: site = db.session.query(Site).filter(Site.code == app_code).first() if not site: raise NotFound("Site not found.") - access_token = WebAppAuthService._get_account_jwt_token(account=account, site=site) + access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id) return access_token @@ -84,8 +88,39 @@ class WebAppAuthService: def revoke_email_code_login_token(cls, token: str): TokenManager.revoke_token(token, "webapp_email_code_login") - @staticmethod - def _get_account_jwt_token(account: Account, site: Site) -> str: + @classmethod + def create_end_user(cls, app_code, email) -> EndUser: + site = db.session.query(Site).filter(Site.code == app_code).first() + app_model = db.session.query(App).filter(App.id == site.app_id).first() + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=False, + session_id=email, + name="enterpriseuser", + external_user_id="enterpriseuser" + ) + db.session.add(end_user) + db.session.commit() + + return end_user + + @classmethod + def _validate_user_accessibility(cls, account: Account, app_code: str): + """Check if the user is allowed to access the app.""" + system_features = FeatureService.get_system_features() + if system_features.webapp_auth.enabled: + app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) + if not app_settings or not app_settings.access_mode == "public": + raise WebAppAuthRequiredError() + if app_settings.access_mode == "private" and not EnterpriseService.is_user_allowed_to_access_webapp( + account.id, app_code=app_code + ): + raise WebAppAuthFailedError() + + @classmethod + def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str: exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24) exp = int(exp_dt.timestamp()) @@ -95,6 +130,7 @@ class WebAppAuthService: "app_id": site.app_id, "app_code": site.code, "user_id": account.id, + "end_user_id": end_user_id, "token_source": "webapp", "exp": exp, } From e9a207b38e68935b925b9bc72c022581a3fa59dd Mon Sep 17 00:00:00 2001 From: GareArc Date: Wed, 9 Apr 2025 16:30:41 -0400 Subject: [PATCH 03/14] fix: adjust enterprise api --- api/controllers/console/app/app.py | 19 ++++++++---- api/controllers/inner_api/__init__.py | 2 +- api/controllers/inner_api/mail.py | 8 +++-- api/controllers/web/login.py | 26 ++++++---------- api/controllers/web/wraps.py | 14 ++++++--- api/fields/app_fields.py | 3 ++ api/services/enterprise/enterprise_service.py | 31 ++++++++++++------- api/services/enterprise/mail_service.py | 14 ++------- api/services/feature_service.py | 8 +++-- api/services/webapp_auth_service.py | 13 +++----- api/tasks/mail_enterprise_task.py | 8 ++--- 11 files changed, 77 insertions(+), 69 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 4aa10ac6e9..ce6da4af79 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -17,15 +17,13 @@ from controllers.console.wraps import ( ) from core.ops.ops_trace_manager import OpsTraceManager from extensions.ext_database import db -from fields.app_fields import ( - app_detail_fields, - app_detail_fields_with_site, - app_pagination_fields, -) +from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from libs.login import login_required from models import Account, App from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] @@ -67,7 +65,12 @@ class AppListApi(Resource): if not app_pagination: return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} - return marshal(app_pagination, app_pagination_fields) + if FeatureService.get_system_features().webapp_auth.enabled: + for app in app_pagination.items: + app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app.id)) + app.access_mode = app_setting.access_mode + + return marshal(app_pagination, app_pagination_fields), 200 @setup_required @login_required @@ -111,6 +114,10 @@ class AppApi(Resource): app_model = app_service.get_app(app_model) + if FeatureService.get_system_features().webapp_auth.enabled: + app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app_model.id)) + app_model.access_mode = app_setting.access_mode + return app_model @setup_required diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index dfedea582d..b1bb9d6545 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -5,5 +5,5 @@ from libs.external_api import ExternalApi bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) -from .workspace import workspace from . import mail +from .workspace import workspace diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index 9f10356fb6..47cbcb713c 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -1,5 +1,7 @@ -from flask_restful import Resource # type: ignore -from flask_restful import reqparse +from flask_restful import ( + Resource, # type: ignore + reqparse, +) from controllers.console.wraps import setup_required from controllers.inner_api import api @@ -12,7 +14,7 @@ class EnterpriseMail(Resource): @inner_api_only def post(self): parser = reqparse.RequestParser() - parser.add_argument("to", type=str, action='append', required=True) + parser.add_argument("to", type=str, action="append", required=True) parser.add_argument("subject", type=str, required=True) parser.add_argument("body", type=str, required=True) parser.add_argument("substitutions", type=dict, required=False) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 955c781989..4106e6a179 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,6 +1,3 @@ -from typing import cast - -import flask_login from flask import request from flask_restful import Resource, reqparse from jwt import InvalidTokenError # type: ignore @@ -8,14 +5,11 @@ from web import api from werkzeug.exceptions import BadRequest import services -from controllers.console.auth.error import (EmailCodeError, - EmailOrPasswordMismatchError, - InvalidEmailError) +from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError from controllers.console.error import AccountBannedError, AccountNotFound from controllers.console.wraps import setup_required from libs.helper import email from libs.password import valid_password -from models.account import Account from services.account_service import AccountService from services.webapp_auth_service import WebAppAuthService @@ -51,14 +45,14 @@ class LoginApi(Resource): return {"result": "success", "token": token} -class LogoutApi(Resource): - @setup_required - def get(self): - account = cast(Account, flask_login.current_user) - if isinstance(account, flask_login.AnonymousUserMixin): - return {"result": "success"} - flask_login.logout_user() - return {"result": "success"} +# class LogoutApi(Resource): +# @setup_required +# def get(self): +# account = cast(Account, flask_login.current_user) +# if isinstance(account, flask_login.AnonymousUserMixin): +# return {"result": "success"} +# flask_login.logout_user() +# return {"result": "success"} class EmailCodeLoginSendEmailApi(Resource): @@ -122,6 +116,6 @@ class EmailCodeLoginApi(Resource): api.add_resource(LoginApi, "/login") -api.add_resource(LogoutApi, "/logout") +# api.add_resource(LogoutApi, "/logout") api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 482d7859fa..a009cd3288 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,8 +4,7 @@ from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import (WebAppAuthFailedError, - WebAppAuthRequiredError) +from controllers.web.error import WebAppAuthFailedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -61,7 +60,9 @@ def decode_jwt_token(): # for enterprise webapp auth app_web_auth_enabled = False if system_features.webapp_auth.enabled: - app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + app_web_auth_enabled = ( + EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + ) _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) @@ -69,7 +70,9 @@ def decode_jwt_token(): return app_model, end_user except Unauthorized as e: if system_features.webapp_auth.enabled: - app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + app_web_auth_enabled = ( + EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + ) if app_web_auth_enabled: raise WebAppAuthRequiredError() @@ -77,7 +80,8 @@ def decode_jwt_token(): def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): - # Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login + # Check if authentication is enforced for web app, and if the token source is not webapp, + # raise an error and redirect to login if system_webapp_auth_enabled and app_web_auth_enabled: source = decoded.get("token_source") if not source or source != "webapp": diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 73800eab85..95eef8fed1 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -63,6 +63,7 @@ app_detail_fields = { "created_at": TimestampField, "updated_by": fields.String, "updated_at": TimestampField, + "access_mode": fields.String, } prompt_config_fields = { @@ -98,6 +99,7 @@ app_partial_fields = { "updated_by": fields.String, "updated_at": TimestampField, "tags": fields.List(fields.Nested(tag_fields)), + "access_mode": fields.String, } @@ -170,6 +172,7 @@ app_detail_fields_with_site = { "updated_by": fields.String, "updated_at": TimestampField, "deleted_tools": fields.List(fields.String), + "access_mode": fields.String, } app_site_fields = { diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 2ff3b3348a..21e4831715 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -7,7 +7,7 @@ class WebAppSettings(BaseModel): access_mode: str = Field( description="Access mode for the web app. Can be 'public' or 'private'", default="private", - alias="access_mode", + alias="accessMode", ) @@ -17,19 +17,28 @@ class EnterpriseService: return EnterpriseRequest.send_request("GET", "/info") @classmethod - def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool: - if not app_id and not app_code: - raise ValueError("Either app_id or app_code must be provided.") + def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool: + params = {"userId": user_id, "appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) - return EnterpriseRequest.send_request( - "GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}" - ) + return data.get("result", False) @classmethod - def get_web_app_settings(cls, app_code: str = None, app_id: str = None): - if not app_code and not app_id: - raise ValueError("Either app_code or app_id must be provided.") - data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}") + def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: + if not app_id: + raise ValueError("app_id must be provided.") + params = {"appId": app_id} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) + + @classmethod + def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: + if not app_code: + raise ValueError("app_code must be provided.") + params = {"appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) if not data: raise ValueError("No data found.") return WebAppSettings(**data) diff --git a/api/services/enterprise/mail_service.py b/api/services/enterprise/mail_service.py index 24b22008a1..630e7679ac 100644 --- a/api/services/enterprise/mail_service.py +++ b/api/services/enterprise/mail_service.py @@ -1,26 +1,18 @@ - -from typing import Dict, List - from pydantic import BaseModel from tasks.mail_enterprise_task import send_enterprise_email_task class DifyMail(BaseModel): - to: List[str] + to: list[str] subject: str body: str - substitutions: Dict[str, str] = {} + substitutions: dict[str, str] = {} class EnterpriseMailService: - @classmethod def send_mail(cls, mail: DifyMail): - send_enterprise_email_task.delay( - to=mail.to, - subject=mail.subject, - body=mail.body, - substitutions=mail.substitutions + to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions ) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index f37f4e0d92..c38c9cb72b 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -205,8 +205,12 @@ class FeatureService: if "WebAppAuth" in enterprise_info: features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False) - features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False) - features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False) + features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get( + "allowEmailCodeLogin", False + ) + features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get( + "allowEmailPasswordLogin", False + ) if "License" in enterprise_info: license_info = enterprise_info["License"] diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 24d1177d87..2f3ef5d97c 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -5,8 +5,7 @@ from typing import Any, Optional, cast from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config -from controllers.web.error import (WebAppAuthFailedError, - WebAppAuthRequiredError) +from controllers.web.error import WebAppAuthFailedError from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService @@ -14,8 +13,7 @@ from libs.password import compare_password from models.account import Account, AccountStatus from models.model import App, EndUser, Site from services.enterprise.enterprise_service import EnterpriseService -from services.errors.account import (AccountLoginError, AccountNotFoundError, - AccountPasswordError) +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task @@ -99,7 +97,7 @@ class WebAppAuthService: is_anonymous=False, session_id=email, name="enterpriseuser", - external_user_id="enterpriseuser" + external_user_id="enterpriseuser", ) db.session.add(end_user) db.session.commit() @@ -112,9 +110,8 @@ class WebAppAuthService: system_features = FeatureService.get_system_features() if system_features.webapp_auth.enabled: app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) - if not app_settings or not app_settings.access_mode == "public": - raise WebAppAuthRequiredError() - if app_settings.access_mode == "private" and not EnterpriseService.is_user_allowed_to_access_webapp( + + if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp( account.id, app_code=app_code ): raise WebAppAuthFailedError() diff --git a/api/tasks/mail_enterprise_task.py b/api/tasks/mail_enterprise_task.py index 67475185db..b9d8fd55df 100644 --- a/api/tasks/mail_enterprise_task.py +++ b/api/tasks/mail_enterprise_task.py @@ -13,9 +13,7 @@ def send_enterprise_email_task(to, subject, body, substitutions): if not mail.is_inited(): return - logging.info( - click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green") - ) + logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green")) start_at = time.perf_counter() try: @@ -29,9 +27,7 @@ def send_enterprise_email_task(to, subject, body, substitutions): end_at = time.perf_counter() logging.info( - click.style( - "Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" - ) + click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") ) except Exception: logging.exception("Send enterprise mail to {} failed".format(to)) From cbea30e65f15789ae91cbf3d0ef608f5cb2e8056 Mon Sep 17 00:00:00 2001 From: GareArc Date: Wed, 9 Apr 2025 17:21:16 -0400 Subject: [PATCH 04/14] fix: bad field name --- api/services/feature_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index c38c9cb72b..e811408c41 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -204,7 +204,7 @@ class FeatureService: features.branding.favicon = enterprise_info["Branding"].get("favicon", "") if "WebAppAuth" in enterprise_info: - features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False) + features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False) features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get( "allowEmailCodeLogin", False ) From b922c8c2150094f28caacee44044ea7b724ede16 Mon Sep 17 00:00:00 2001 From: GareArc Date: Thu, 10 Apr 2025 00:36:35 -0400 Subject: [PATCH 05/14] fix: make app private when created --- api/services/app_service.py | 15 ++++++++++++--- api/services/enterprise/enterprise_service.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index 1fd7cb5e33..59a917d4fd 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -12,8 +12,10 @@ from core.agent.entities import AgentToolEntity from core.app.features.rate_limiting import RateLimit from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.entities.model_entities import (ModelPropertyKey, + ModelType) +from core.model_runtime.model_providers.__base.large_language_model import \ + LargeLanguageModel from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created @@ -21,8 +23,11 @@ from extensions.ext_database import db from models.account import Account from models.model import App, AppMode, AppModelConfig from models.tools import ApiToolProvider +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService from services.tag_service import TagService -from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task +from tasks.remove_app_and_related_data_task import \ + remove_app_and_related_data_task class AppService: @@ -152,6 +157,10 @@ class AppService: app_was_created.send(app, account=account) + if FeatureService.get_system_features().webapp_auth.enabled: + # update web app setting as private + EnterpriseService.update_app_access_mode(app.id, "private") + return app def get_app(self, app: App) -> App: diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 21e4831715..c1fccd0dff 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -42,3 +42,19 @@ class EnterpriseService: if not data: raise ValueError("No data found.") return WebAppSettings(**data) + + @classmethod + def update_app_access_mode(cls, app_id: str, access_mode: str) -> bool: + if not app_id: + raise ValueError("app_id must be provided.") + if access_mode not in ["public", "private", "private_all"]: + raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + + data = { + "appId": app_id, + "accessMode": access_mode + } + + response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) + + return response.get("result", False) From 4105c8ff70db6fb012994aec336a49d24b99acc5 Mon Sep 17 00:00:00 2001 From: GareArc Date: Thu, 10 Apr 2025 06:27:00 -0400 Subject: [PATCH 06/14] fix: bad api call --- api/controllers/web/passport.py | 2 +- api/services/webapp_auth_service.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 3c1f0a415f..3c07b3e87d 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -23,7 +23,7 @@ class PassportResource(Resource): raise Unauthorized("X-App-Code header is missing.") if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) + app_settings = EnterpriseService.get_app_access_mode_by_code(app_code=app_code) if not app_settings or not app_settings.access_mode == "public": raise WebAppAuthRequiredError() diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 2f3ef5d97c..f9dd80a729 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -13,7 +13,8 @@ from libs.password import compare_password from models.account import Account, AccountStatus from models.model import App, EndUser, Site from services.enterprise.enterprise_service import EnterpriseService -from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError +from services.errors.account import (AccountLoginError, AccountNotFoundError, + AccountPasswordError) from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task @@ -109,7 +110,7 @@ class WebAppAuthService: """Check if the user is allowed to access the app.""" system_features = FeatureService.get_system_features() if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) + app_settings = EnterpriseService.get_app_access_mode_by_code(app_code=app_code) if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp( account.id, app_code=app_code From 4785c061a9261b9b50c799dc0a043bce5aa0557e Mon Sep 17 00:00:00 2001 From: GareArc Date: Thu, 10 Apr 2025 15:19:28 -0400 Subject: [PATCH 07/14] feat: add webapp clean up --- api/services/app_service.py | 4 ++++ api/services/enterprise/enterprise_service.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/api/services/app_service.py b/api/services/app_service.py index 59a917d4fd..03393c00fa 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -317,6 +317,10 @@ class AppService: db.session.delete(app) db.session.commit() + # clean up web app settings + if FeatureService.get_system_features().webapp_auth.enabled: + EnterpriseService.cleanup_webapp(app.id) + # Trigger asynchronous deletion of app and related data remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index c1fccd0dff..299764ffc4 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -58,3 +58,13 @@ class EnterpriseService: response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) return response.get("result", False) + + @classmethod + def cleanup_webapp(cls, app_id: str): + if not app_id: + raise ValueError("app_id must be provided.") + + body = { + "appId": app_id + } + EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) From 7a4ec9cf238525d8dd40e222641e46832ccc2a2f Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 02:41:02 -0400 Subject: [PATCH 08/14] fix: change error code for webapp auth --- api/controllers/web/error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 4909694d26..45ab93d324 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -122,7 +122,7 @@ class UnsupportedFileTypeError(BaseHTTPException): class WebAppAuthRequiredError(BaseHTTPException): - error_code = "web_auth_required" + error_code = "web_sso_auth_required" description = "Web app authentication required." code = 401 From a1dc3cfdecf5a91c505e707a7f1b1327b3901494 Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 02:45:46 -0400 Subject: [PATCH 09/14] fix: update code for access denied error --- api/controllers/web/error.py | 4 ++-- api/controllers/web/wraps.py | 5 +++-- api/services/webapp_auth_service.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 45ab93d324..4371e679db 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -127,8 +127,8 @@ class WebAppAuthRequiredError(BaseHTTPException): code = 401 -class WebAppAuthFailedError(BaseHTTPException): - error_code = "web_app_auth_failed" +class WebAppAuthAccessDeniedError(BaseHTTPException): + error_code = "web_app_access_denied" description = "You do not have permission to access this web app." code = 401 diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index a009cd3288..5a74296b82 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,7 +4,8 @@ from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import WebAppAuthFailedError, WebAppAuthRequiredError +from controllers.web.error import (WebAppAuthAccessDeniedError, + WebAppAuthRequiredError) from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -103,7 +104,7 @@ def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, raise WebAppAuthRequiredError() if not EnterpriseService.is_user_allowed_to_access_webapp(user_id, app_code=app_code): - raise WebAppAuthFailedError() + raise WebAppAuthAccessDeniedError() class WebApiResource(Resource): diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index f9dd80a729..6a4a9c795e 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -5,7 +5,7 @@ from typing import Any, Optional, cast from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config -from controllers.web.error import WebAppAuthFailedError +from controllers.web.error import WebAppAuthAccessDeniedError from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService @@ -115,7 +115,7 @@ class WebAppAuthService: if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp( account.id, app_code=app_code ): - raise WebAppAuthFailedError() + raise WebAppAuthAccessDeniedError() @classmethod def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str: From cb13b53ccd8172e773cf067854291dce61fb4850 Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 03:25:58 -0400 Subject: [PATCH 10/14] fix: update webapp sso features --- api/services/feature_service.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index e811408c41..7575d4101b 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -44,9 +44,14 @@ class BrandingModel(BaseModel): favicon: str = "" +class WebAppAuthSSOModel(BaseModel): + protocol: str = "" + + class WebAppAuthModel(BaseModel): enabled: bool = False allow_sso: bool = False + sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel() allow_email_code_login: bool = False allow_email_password_login: bool = False @@ -71,9 +76,6 @@ class FeatureModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" - sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = "" - enable_web_sso_switch_component: bool = False enable_email_code_login: bool = False enable_email_password_login: bool = True enable_social_oauth_login: bool = False @@ -107,7 +109,6 @@ class FeatureService: cls._fulfill_system_params_from_env(system_features) if dify_config.ENTERPRISE_ENABLED: - system_features.enable_web_sso_switch_component = True system_features.branding.enabled = True system_features.webapp_auth.enabled = True cls._fulfill_params_from_enterprise(system_features) @@ -170,21 +171,12 @@ class FeatureService: features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] @classmethod - def _fulfill_params_from_enterprise(cls, features): + def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): enterprise_info = EnterpriseService.get_info() if "SSOEnforcedForSignin" in enterprise_info: features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"] - if "SSOEnforcedForSigninProtocol" in enterprise_info: - features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"] - - if "SSOEnforcedForWeb" in enterprise_info: - features.sso_enforced_for_web = enterprise_info["SSOEnforcedForWeb"] - - if "SSOEnforcedForWebProtocol" in enterprise_info: - features.sso_enforced_for_web_protocol = enterprise_info["SSOEnforcedForWebProtocol"] - if "EnableEmailCodeLogin" in enterprise_info: features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"] @@ -211,6 +203,9 @@ class FeatureService: features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get( "allowEmailPasswordLogin", False ) + features.webapp_auth.sso_config.protocol = enterprise_info.get( + "SSOEnforcedForSigninProtocol", "" + ) if "License" in enterprise_info: license_info = enterprise_info["License"] From 5f87bdbe3ac2abdd981060312b625bd87babc963 Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 15:24:32 -0400 Subject: [PATCH 11/14] fix: add batch get access mode api --- api/controllers/console/app/app.py | 11 +- api/controllers/web/passport.py | 2 +- api/controllers/web/wraps.py | 9 +- api/services/app_service.py | 13 +-- api/services/enterprise/enterprise_service.py | 107 +++++++++++------- api/services/feature_service.py | 4 +- api/services/webapp_auth_service.py | 10 +- 7 files changed, 87 insertions(+), 69 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index ce6da4af79..7ab594eb26 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -66,9 +66,14 @@ class AppListApi(Resource): return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} if FeatureService.get_system_features().webapp_auth.enabled: + app_ids = [str(app.id) for app in app_pagination.items] + res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids) + if len(res) != len(app_ids): + raise BadRequest("Invalid app id in webapp auth") + for app in app_pagination.items: - app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app.id)) - app.access_mode = app_setting.access_mode + if str(app.id) in res: + app.access_mode = res[str(app.id)].access_mode return marshal(app_pagination, app_pagination_fields), 200 @@ -115,7 +120,7 @@ class AppApi(Resource): app_model = app_service.get_app(app_model) if FeatureService.get_system_features().webapp_auth.enabled: - app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app_model.id)) + app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode return app_model diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 3c07b3e87d..8ab9b84574 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -23,7 +23,7 @@ class PassportResource(Resource): raise Unauthorized("X-App-Code header is missing.") if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.get_app_access_mode_by_code(app_code=app_code) + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) if not app_settings or not app_settings.access_mode == "public": raise WebAppAuthRequiredError() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 5a74296b82..8d35b8e4be 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,8 +4,7 @@ from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import (WebAppAuthAccessDeniedError, - WebAppAuthRequiredError) +from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -62,7 +61,7 @@ def decode_jwt_token(): app_web_auth_enabled = False if system_features.webapp_auth.enabled: app_web_auth_enabled = ( - EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" ) _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) @@ -72,7 +71,7 @@ def decode_jwt_token(): except Unauthorized as e: if system_features.webapp_auth.enabled: app_web_auth_enabled = ( - EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" ) if app_web_auth_enabled: raise WebAppAuthRequiredError() @@ -103,7 +102,7 @@ def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, if not user_id: raise WebAppAuthRequiredError() - if not EnterpriseService.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): raise WebAppAuthAccessDeniedError() diff --git a/api/services/app_service.py b/api/services/app_service.py index 03393c00fa..9359bb2844 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -12,10 +12,8 @@ from core.agent.entities import AgentToolEntity from core.app.features.rate_limiting import RateLimit from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import (ModelPropertyKey, - ModelType) -from core.model_runtime.model_providers.__base.large_language_model import \ - LargeLanguageModel +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created @@ -26,8 +24,7 @@ from models.tools import ApiToolProvider from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.tag_service import TagService -from tasks.remove_app_and_related_data_task import \ - remove_app_and_related_data_task +from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task class AppService: @@ -159,7 +156,7 @@ class AppService: if FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private - EnterpriseService.update_app_access_mode(app.id, "private") + EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private") return app @@ -319,7 +316,7 @@ class AppService: # clean up web app settings if FeatureService.get_system_features().webapp_auth.enabled: - EnterpriseService.cleanup_webapp(app.id) + EnterpriseService.WebAppAuth.cleanup_webapp(app.id) # Trigger asynchronous deletion of app and related data remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 299764ffc4..a3e4d163c3 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,3 +1,5 @@ +import logging + from pydantic import BaseModel, Field from services.enterprise.base import EnterpriseRequest @@ -16,55 +18,72 @@ class EnterpriseService: def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") - @classmethod - def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool: - params = {"userId": user_id, "appCode": app_code} - data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) + class WebAppAuth: + @classmethod + def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool: + params = {"userId": user_id, "appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) - return data.get("result", False) + return data.get("result", False) - @classmethod - def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: - if not app_id: - raise ValueError("app_id must be provided.") - params = {"appId": app_id} - data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) - if not data: - raise ValueError("No data found.") - return WebAppSettings(**data) + @classmethod + def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: + if not app_id: + raise ValueError("app_id must be provided.") + params = {"appId": app_id} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) - @classmethod - def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: - if not app_code: - raise ValueError("app_code must be provided.") - params = {"appCode": app_code} - data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) - if not data: - raise ValueError("No data found.") - return WebAppSettings(**data) + @classmethod + def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: + if not app_ids: + raise ValueError("app_ids must be provided.") + params = {"appIds": ",".join(app_ids)} + data: dict[str, str] = EnterpriseRequest.send_request("GET", "/webapp/access-mode/batch/id", params=params) + if not data: + raise ValueError("No data found.") - @classmethod - def update_app_access_mode(cls, app_id: str, access_mode: str) -> bool: - if not app_id: - raise ValueError("app_id must be provided.") - if access_mode not in ["public", "private", "private_all"]: - raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + logging.info(f"Batch get app access mode by id returns data: {data}") - data = { - "appId": app_id, - "accessMode": access_mode - } + if not isinstance(data, dict): + raise ValueError("Invalid data format.") - response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) + for key, value in data.items(): + curr = WebAppSettings() + curr.access_mode = value + data[key] = curr - return response.get("result", False) + return data - @classmethod - def cleanup_webapp(cls, app_id: str): - if not app_id: - raise ValueError("app_id must be provided.") - - body = { - "appId": app_id - } - EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) + @classmethod + def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: + if not app_code: + raise ValueError("app_code must be provided.") + params = {"appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) + + @classmethod + def update_app_access_mode(cls, app_id: str, access_mode: str) -> bool: + if not app_id: + raise ValueError("app_id must be provided.") + if access_mode not in ["public", "private", "private_all"]: + raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + + data = {"appId": app_id, "accessMode": access_mode} + + response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) + + return response.get("result", False) + + @classmethod + def cleanup_webapp(cls, app_id: str): + if not app_id: + raise ValueError("app_id must be provided.") + + body = {"appId": app_id} + EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 7575d4101b..e62a94cc9d 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -203,9 +203,7 @@ class FeatureService: features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get( "allowEmailPasswordLogin", False ) - features.webapp_auth.sso_config.protocol = enterprise_info.get( - "SSOEnforcedForSigninProtocol", "" - ) + features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForSigninProtocol", "") if "License" in enterprise_info: license_info = enterprise_info["License"] diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 6a4a9c795e..506b7698e0 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -13,8 +13,7 @@ from libs.password import compare_password from models.account import Account, AccountStatus from models.model import App, EndUser, Site from services.enterprise.enterprise_service import EnterpriseService -from services.errors.account import (AccountLoginError, AccountNotFoundError, - AccountPasswordError) +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task @@ -110,10 +109,11 @@ class WebAppAuthService: """Check if the user is allowed to access the app.""" system_features = FeatureService.get_system_features() if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.get_app_access_mode_by_code(app_code=app_code) + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) - if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp( - account.id, app_code=app_code + if ( + app_settings.access_mode != "public" + and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code) ): raise WebAppAuthAccessDeniedError() From d5b75470e447586e11e74b39fe5c1e90750a59d3 Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 16:48:09 -0400 Subject: [PATCH 12/14] fix: bad request --- api/services/enterprise/enterprise_service.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index a3e4d163c3..d00d90d994 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -40,22 +40,22 @@ class EnterpriseService: def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: if not app_ids: raise ValueError("app_ids must be provided.") - params = {"appIds": ",".join(app_ids)} - data: dict[str, str] = EnterpriseRequest.send_request("GET", "/webapp/access-mode/batch/id", params=params) + body = {"appIds": app_ids} + data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body) if not data: raise ValueError("No data found.") - logging.info(f"Batch get app access mode by id returns data: {data}") - if not isinstance(data, dict): + logging.info(f"Batch get app access mode by id returns data: {data}") raise ValueError("Invalid data format.") + ret = {} for key, value in data.items(): curr = WebAppSettings() curr.access_mode = value - data[key] = curr + ret[key] = curr - return data + return ret @classmethod def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: From fa76590c2409ddead4c7909369d549b971ca879a Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 16:59:52 -0400 Subject: [PATCH 13/14] chore: add log --- api/controllers/console/app/app.py | 18 ++++++++++-------- api/services/enterprise/enterprise_service.py | 1 + 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 7ab594eb26..1431dca7b4 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,23 +1,24 @@ +import logging import uuid from typing import cast from flask_login import current_user # type: ignore -from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore +from flask_restful import (Resource, inputs, marshal, # type: ignore + marshal_with, reqparse) from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import ( - account_initialization_required, - cloud_edition_billing_resource_check, - enterprise_license_required, - setup_required, -) +from controllers.console.wraps import (account_initialization_required, + cloud_edition_billing_resource_check, + enterprise_license_required, + setup_required) from core.ops.ops_trace_manager import OpsTraceManager from extensions.ext_database import db -from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields +from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, + app_pagination_fields) from libs.login import login_required from models import Account, App from services.app_dsl_service import AppDslService, ImportMode @@ -67,6 +68,7 @@ class AppListApi(Resource): if FeatureService.get_system_features().webapp_auth.enabled: app_ids = [str(app.id) for app in app_pagination.items] + logging.info(f"app_ids: {app_ids}") res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids) if len(res) != len(app_ids): raise BadRequest("Invalid app id in webapp auth") diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index d00d90d994..3e6f9e27e4 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -55,6 +55,7 @@ class EnterpriseService: curr.access_mode = value ret[key] = curr + logging.info(f"Batch get app access mode by id returns data: {ret}") return ret @classmethod From bafdbade526476b21e1780fde41d39d9ae973d6f Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 17:19:34 -0400 Subject: [PATCH 14/14] fix: wrong json structure --- api/services/enterprise/enterprise_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 3e6f9e27e4..e44e7f6658 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -45,12 +45,12 @@ class EnterpriseService: if not data: raise ValueError("No data found.") - if not isinstance(data, dict): + if not isinstance(data['accessModes'], dict): logging.info(f"Batch get app access mode by id returns data: {data}") raise ValueError("Invalid data format.") ret = {} - for key, value in data.items(): + for key, value in data['accessModes'].items(): curr = WebAppSettings() curr.access_mode = value ret[key] = curr