Merge remote-tracking branch 'origin/e-260' into feat/e-priced-limit

# Conflicts:
#	api/services/enterprise/mail_service.py
pull/17683/head
zhangx1n 1 year ago
commit bcb84f563c

@ -1,31 +1,30 @@
import logging
import uuid import uuid
from typing import cast from typing import cast
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore from flask_restful import (Resource, inputs, marshal, # type: ignore
marshal_with, reqparse)
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (account_initialization_required,
account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
enterprise_license_required, enterprise_license_required,
setup_required, setup_required)
)
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import ( from fields.app_fields import (app_detail_fields, app_detail_fields_with_site,
app_detail_fields, app_pagination_fields)
app_detail_fields_with_site,
app_pagination_fields,
)
from libs.login import login_required from libs.login import login_required
from models import Account, App from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@ -67,7 +66,18 @@ class AppListApi(Resource):
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
return marshal(app_pagination, app_pagination_fields) if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
logging.info(f"app_ids: {app_ids}")
res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids)
if len(res) != len(app_ids):
raise BadRequest("Invalid app id in webapp auth")
for app in app_pagination.items:
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
return marshal(app_pagination, app_pagination_fields), 200
@setup_required @setup_required
@login_required @login_required
@ -111,6 +121,10 @@ class AppApi(Resource):
app_model = app_service.get_app(app_model) app_model = app_service.get_app(app_model)
if FeatureService.get_system_features().webapp_auth.enabled:
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
return app_model return app_model
@setup_required @setup_required

@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415 code = 415
class WebSSOAuthRequiredError(BaseHTTPException): class WebAppAuthRequiredError(BaseHTTPException):
error_code = "web_sso_auth_required" error_code = "web_sso_auth_required"
description = "Web SSO authentication required." description = "Web app authentication required."
code = 401
class WebAppAuthAccessDeniedError(BaseHTTPException):
error_code = "web_app_access_denied"
description = "You do not have permission to access this web app."
code = 401 code = 401

@ -0,0 +1,121 @@
from flask import request
from flask_restful import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
from web import api
from werkzeug.exceptions import BadRequest
import services
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
from controllers.console.error import AccountBannedError, AccountNotFound
from controllers.console.wraps import setup_required
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
from services.webapp_auth_service import WebAppAuthService
class LoginApi(Resource):
"""Resource for web app email/password login."""
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
try:
account = WebAppAuthService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
return {"result": "success", "token": token}
# class LogoutApi(Resource):
# @setup_required
# def get(self):
# account = cast(Account, flask_login.current_user)
# if isinstance(account, flask_login.AnonymousUserMixin):
# return {"result": "success"}
# flask_login.logout_user()
# return {"result": "success"}
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = WebAppAuthService.get_user_through_email(args["email"])
if account is None:
raise AccountNotFound()
else:
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email)
if not account:
raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "token": token}
api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

@ -5,7 +5,7 @@ from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web import api from controllers.web import api
from controllers.web.error import WebSSOAuthRequiredError from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
@ -22,10 +22,10 @@ class PassportResource(Resource):
if app_code is None: if app_code is None:
raise Unauthorized("X-App-Code header is missing.") raise Unauthorized("X-App-Code header is missing.")
if system_features.sso_enforced_for_web: if system_features.webapp_auth.enabled:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if app_web_sso_enabled: if not app_settings or not app_settings.access_mode == "public":
raise WebSSOAuthRequiredError() raise WebAppAuthRequiredError()
# get site from db and check if it is normal # get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource # type: ignore from flask_restful import Resource # type: ignore
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebSSOAuthRequiredError from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
from extensions.ext_database import db from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
@ -57,35 +57,53 @@ def decode_jwt_token():
if not end_user: if not end_user:
raise NotFound() raise NotFound()
_validate_web_sso_token(decoded, system_features, app_code) # for enterprise webapp auth
app_web_auth_enabled = False
if system_features.webapp_auth.enabled:
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
)
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
return app_model, end_user return app_model, end_user
except Unauthorized as e: except Unauthorized as e:
if system_features.sso_enforced_for_web: if system_features.webapp_auth.enabled:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) app_web_auth_enabled = (
if app_web_sso_enabled: EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
raise WebSSOAuthRequiredError() )
if app_web_auth_enabled:
raise WebAppAuthRequiredError()
raise Unauthorized(e.description) raise Unauthorized(e.description)
def _validate_web_sso_token(decoded, system_features, app_code): def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
app_web_sso_enabled = False # Check if authentication is enforced for web app, and if the token source is not webapp,
# raise an error and redirect to login
# Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login if system_webapp_auth_enabled and app_web_auth_enabled:
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
source = decoded.get("token_source") source = decoded.get("token_source")
if not source or source != "sso": if not source or source != "webapp":
raise WebSSOAuthRequiredError() raise WebAppAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO, # Check if authentication is not enforced for web, and if the token source is webapp,
# raise an error and redirect to normal passport login # raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled: if not system_webapp_auth_enabled or not app_web_auth_enabled:
source = decoded.get("token_source") source = decoded.get("token_source")
if source and source == "sso": if source and source == "webapp":
raise Unauthorized("sso token expired.") raise Unauthorized("webapp token expired.")
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
if system_webapp_auth_enabled and app_web_auth_enabled:
# Check if the user is allowed to access the web app
user_id = decoded.get("user_id")
if not user_id:
raise WebAppAuthRequiredError()
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthAccessDeniedError()
class WebApiResource(Resource): class WebApiResource(Resource):

@ -63,6 +63,7 @@ app_detail_fields = {
"created_at": TimestampField, "created_at": TimestampField,
"updated_by": fields.String, "updated_by": fields.String,
"updated_at": TimestampField, "updated_at": TimestampField,
"access_mode": fields.String,
} }
prompt_config_fields = { prompt_config_fields = {
@ -98,6 +99,7 @@ app_partial_fields = {
"updated_by": fields.String, "updated_by": fields.String,
"updated_at": TimestampField, "updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_fields)), "tags": fields.List(fields.Nested(tag_fields)),
"access_mode": fields.String,
} }
@ -170,6 +172,7 @@ app_detail_fields_with_site = {
"updated_by": fields.String, "updated_by": fields.String,
"updated_at": TimestampField, "updated_at": TimestampField,
"deleted_tools": fields.List(fields.String), "deleted_tools": fields.List(fields.String),
"access_mode": fields.String,
} }
app_site_fields = { app_site_fields = {

@ -21,6 +21,8 @@ from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import App, AppMode, AppModelConfig from models.model import App, AppMode, AppModelConfig
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.tag_service import TagService from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
@ -152,6 +154,10 @@ class AppService:
app_was_created.send(app, account=account) app_was_created.send(app, account=account)
if FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private")
return app return app
def get_app(self, app: App) -> App: def get_app(self, app: App) -> App:
@ -308,6 +314,10 @@ class AppService:
db.session.delete(app) db.session.delete(app)
db.session.commit() db.session.commit()
# clean up web app settings
if FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
# Trigger asynchronous deletion of app and related data # Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)

@ -1,11 +1,90 @@
import logging
from pydantic import BaseModel, Field
from services.enterprise.base import EnterpriseRequest from services.enterprise.base import EnterpriseRequest
class WebAppSettings(BaseModel):
access_mode: str = Field(
description="Access mode for the web app. Can be 'public' or 'private'",
default="private",
alias="accessMode",
)
class EnterpriseService: class EnterpriseService:
@classmethod @classmethod
def get_info(cls, tenant_id=None): def get_info(cls, tenant_id=None):
return EnterpriseRequest.send_request("GET", "/info", tenant_id=tenant_id) return EnterpriseRequest.send_request("GET", "/info", tenant_id=tenant_id)
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool:
params = {"userId": user_id, "appCode": app_code}
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
return data.get("result", False)
@classmethod
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
if not app_id:
raise ValueError("app_id must be provided.")
params = {"appId": app_id}
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
if not app_ids:
raise ValueError("app_ids must be provided.")
body = {"appIds": app_ids}
data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body)
if not data:
raise ValueError("No data found.")
if not isinstance(data['accessModes'], dict):
logging.info(f"Batch get app access mode by id returns data: {data}")
raise ValueError("Invalid data format.")
ret = {}
for key, value in data['accessModes'].items():
curr = WebAppSettings()
curr.access_mode = value
ret[key] = curr
logging.info(f"Batch get app access mode by id returns data: {ret}")
return ret
@classmethod @classmethod
def get_app_web_sso_enabled(cls, app_code): def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") if not app_code:
raise ValueError("app_code must be provided.")
params = {"appCode": app_code}
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str) -> bool:
if not app_id:
raise ValueError("app_id must be provided.")
if access_mode not in ["public", "private", "private_all"]:
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
data = {"appId": app_id, "accessMode": access_mode}
response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data)
return response.get("result", False)
@classmethod
def cleanup_webapp(cls, app_id: str):
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)

@ -1,15 +1,13 @@
from typing import Dict, List
from pydantic import BaseModel from pydantic import BaseModel
from tasks.mail_enterprise_task import send_enterprise_email_task from tasks.mail_enterprise_task import send_enterprise_email_task
class DifyMail(BaseModel): class DifyMail(BaseModel):
to: List[str] to: list[str]
subject: str subject: str
body: str body: str
substitutions: Dict[str, str] = {} substitutions: dict[str, str] = {}
class EnterpriseMailService: class EnterpriseMailService:

@ -46,6 +46,18 @@ class BrandingModel(BaseModel):
favicon: str = "" favicon: str = ""
class WebAppAuthSSOModel(BaseModel):
protocol: str = ""
class WebAppAuthModel(BaseModel):
enabled: bool = False
allow_sso: bool = False
sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel()
allow_email_code_login: bool = False
allow_email_password_login: bool = False
class FeatureModel(BaseModel): class FeatureModel(BaseModel):
billing: BillingModel = BillingModel() billing: BillingModel = BillingModel()
members: LimitationModel = LimitationModel(size=0, limit=1) members: LimitationModel = LimitationModel(size=0, limit=1)
@ -67,9 +79,6 @@ class FeatureModel(BaseModel):
class SystemFeatureModel(BaseModel): class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = "" sso_enforced_for_signin_protocol: str = ""
sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = ""
enable_web_sso_switch_component: bool = False
enable_email_code_login: bool = False enable_email_code_login: bool = False
enable_email_password_login: bool = True enable_email_password_login: bool = True
enable_social_oauth_login: bool = False enable_social_oauth_login: bool = False
@ -78,6 +87,7 @@ class SystemFeatureModel(BaseModel):
is_email_setup: bool = False is_email_setup: bool = False
license: LicenseModel = LicenseModel() license: LicenseModel = LicenseModel()
branding: BrandingModel = BrandingModel() branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel()
class FeatureService: class FeatureService:
@ -103,8 +113,8 @@ class FeatureService:
cls._fulfill_system_params_from_env(system_features) cls._fulfill_system_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED: if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True
system_features.branding.enabled = True system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
cls._fulfill_params_from_enterprise(system_features) cls._fulfill_params_from_enterprise(system_features)
return system_features return system_features
@ -171,21 +181,12 @@ class FeatureService:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
@classmethod @classmethod
def _fulfill_params_from_enterprise(cls, features): def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
enterprise_info = EnterpriseService.get_info() enterprise_info = EnterpriseService.get_info()
if "SSOEnforcedForSignin" in enterprise_info: if "SSOEnforcedForSignin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"] features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"]
if "SSOEnforcedForSigninProtocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"]
if "SSOEnforcedForWeb" in enterprise_info:
features.sso_enforced_for_web = enterprise_info["SSOEnforcedForWeb"]
if "SSOEnforcedForWebProtocol" in enterprise_info:
features.sso_enforced_for_web_protocol = enterprise_info["SSOEnforcedForWebProtocol"]
if "EnableEmailCodeLogin" in enterprise_info: if "EnableEmailCodeLogin" in enterprise_info:
features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"] features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"]
@ -204,6 +205,16 @@ class FeatureService:
features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "") features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
features.branding.favicon = enterprise_info["Branding"].get("favicon", "") features.branding.favicon = enterprise_info["Branding"].get("favicon", "")
if "WebAppAuth" in enterprise_info:
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False)
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get(
"allowEmailCodeLogin", False
)
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get(
"allowEmailPasswordLogin", False
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForSigninProtocol", "")
if "License" in enterprise_info: if "License" in enterprise_info:
license_info = enterprise_info["License"] license_info = enterprise_info["License"]

@ -0,0 +1,137 @@
import random
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web.error import WebAppAuthAccessDeniedError
from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.feature_service import FeatureService
from tasks.mail_email_code_login import send_email_code_login_mail_task
class WebAppAuthService:
"""Service for web app authentication."""
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = Account.query.filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountPasswordError("Invalid email or password.")
return cast(Account, account)
@classmethod
def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
site = db.session.query(Site).filter(Site.code == app_code).first()
if not site:
raise NotFound("Site not found.")
access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
return access_token
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).filter(Account.email == email).first()
if not account:
return None
if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.")
return account
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
email = account.email if account else email
if email is None:
raise ValueError("Email must be provided.")
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
)
send_email_code_login_mail_task.delay(
language=language,
to=account.email if account else email,
code=code,
)
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "webapp_email_code_login")
@classmethod
def revoke_email_code_login_token(cls, token: str):
TokenManager.revoke_token(token, "webapp_email_code_login")
@classmethod
def create_end_user(cls, app_code, email) -> EndUser:
site = db.session.query(Site).filter(Site.code == app_code).first()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=False,
session_id=email,
name="enterpriseuser",
external_user_id="enterpriseuser",
)
db.session.add(end_user)
db.session.commit()
return end_user
@classmethod
def _validate_user_accessibility(cls, account: Account, app_code: str):
"""Check if the user is allowed to access the app."""
system_features = FeatureService.get_system_features()
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if (
app_settings.access_mode != "public"
and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
):
raise WebAppAuthAccessDeniedError()
@classmethod
def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24)
exp = int(exp_dt.timestamp())
payload = {
"iss": site.id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": account.id,
"end_user_id": end_user_id,
"token_source": "webapp",
"exp": exp,
}
token: str = PassportService().issue(payload)
return token
Loading…
Cancel
Save