Merge remote-tracking branch 'origin/e-260' into feat/e-priced-limit
# Conflicts: # api/services/enterprise/mail_service.pypull/17683/head
commit
bcb84f563c
@ -0,0 +1,121 @@
|
|||||||
|
from flask import request
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
from jwt import InvalidTokenError # type: ignore
|
||||||
|
from web import api
|
||||||
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||||
|
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||||
|
from controllers.console.wraps import setup_required
|
||||||
|
from libs.helper import email
|
||||||
|
from libs.password import valid_password
|
||||||
|
from services.account_service import AccountService
|
||||||
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
|
|
||||||
|
class LoginApi(Resource):
|
||||||
|
"""Resource for web app email/password login."""
|
||||||
|
|
||||||
|
def post(self):
|
||||||
|
"""Authenticate user and login."""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
|
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
app_code = request.headers.get("X-App-Code")
|
||||||
|
if app_code is None:
|
||||||
|
raise BadRequest("X-App-Code header is missing.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||||
|
except services.errors.account.AccountLoginError:
|
||||||
|
raise AccountBannedError()
|
||||||
|
except services.errors.account.AccountPasswordError:
|
||||||
|
raise EmailOrPasswordMismatchError()
|
||||||
|
except services.errors.account.AccountNotFoundError:
|
||||||
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||||
|
|
||||||
|
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
|
||||||
|
|
||||||
|
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||||
|
return {"result": "success", "token": token}
|
||||||
|
|
||||||
|
|
||||||
|
# class LogoutApi(Resource):
|
||||||
|
# @setup_required
|
||||||
|
# def get(self):
|
||||||
|
# account = cast(Account, flask_login.current_user)
|
||||||
|
# if isinstance(account, flask_login.AnonymousUserMixin):
|
||||||
|
# return {"result": "success"}
|
||||||
|
# flask_login.logout_user()
|
||||||
|
# return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
|
parser.add_argument("language", type=str, required=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||||
|
language = "zh-Hans"
|
||||||
|
else:
|
||||||
|
language = "en-US"
|
||||||
|
|
||||||
|
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||||
|
if account is None:
|
||||||
|
raise AccountNotFound()
|
||||||
|
else:
|
||||||
|
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||||
|
|
||||||
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("code", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("token", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
user_email = args["email"]
|
||||||
|
app_code = request.headers.get("X-App-Code")
|
||||||
|
if app_code is None:
|
||||||
|
raise BadRequest("X-App-Code header is missing.")
|
||||||
|
|
||||||
|
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||||
|
if token_data is None:
|
||||||
|
raise InvalidTokenError()
|
||||||
|
|
||||||
|
if token_data["email"] != args["email"]:
|
||||||
|
raise InvalidEmailError()
|
||||||
|
|
||||||
|
if token_data["code"] != args["code"]:
|
||||||
|
raise EmailCodeError()
|
||||||
|
|
||||||
|
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||||
|
account = WebAppAuthService.get_user_through_email(user_email)
|
||||||
|
if not account:
|
||||||
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||||
|
|
||||||
|
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
|
||||||
|
|
||||||
|
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||||
|
AccountService.reset_login_error_rate_limit(args["email"])
|
||||||
|
return {"result": "success", "token": token}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(LoginApi, "/login")
|
||||||
|
# api.add_resource(LogoutApi, "/logout")
|
||||||
|
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||||
|
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||||
@ -1,11 +1,90 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from services.enterprise.base import EnterpriseRequest
|
from services.enterprise.base import EnterpriseRequest
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppSettings(BaseModel):
|
||||||
|
access_mode: str = Field(
|
||||||
|
description="Access mode for the web app. Can be 'public' or 'private'",
|
||||||
|
default="private",
|
||||||
|
alias="accessMode",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseService:
|
class EnterpriseService:
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_info(cls, tenant_id=None):
|
def get_info(cls, tenant_id=None):
|
||||||
return EnterpriseRequest.send_request("GET", "/info", tenant_id=tenant_id)
|
return EnterpriseRequest.send_request("GET", "/info", tenant_id=tenant_id)
|
||||||
|
|
||||||
|
class WebAppAuth:
|
||||||
|
@classmethod
|
||||||
|
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool:
|
||||||
|
params = {"userId": user_id, "appCode": app_code}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
|
||||||
|
|
||||||
|
return data.get("result", False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("app_id must be provided.")
|
||||||
|
params = {"appId": app_id}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
return WebAppSettings(**data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
|
||||||
|
if not app_ids:
|
||||||
|
raise ValueError("app_ids must be provided.")
|
||||||
|
body = {"appIds": app_ids}
|
||||||
|
data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body)
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
|
||||||
|
if not isinstance(data['accessModes'], dict):
|
||||||
|
logging.info(f"Batch get app access mode by id returns data: {data}")
|
||||||
|
raise ValueError("Invalid data format.")
|
||||||
|
|
||||||
|
ret = {}
|
||||||
|
for key, value in data['accessModes'].items():
|
||||||
|
curr = WebAppSettings()
|
||||||
|
curr.access_mode = value
|
||||||
|
ret[key] = curr
|
||||||
|
|
||||||
|
logging.info(f"Batch get app access mode by id returns data: {ret}")
|
||||||
|
return ret
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_app_web_sso_enabled(cls, app_code):
|
def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
|
||||||
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
|
if not app_code:
|
||||||
|
raise ValueError("app_code must be provided.")
|
||||||
|
params = {"appCode": app_code}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
return WebAppSettings(**data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_app_access_mode(cls, app_id: str, access_mode: str) -> bool:
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("app_id must be provided.")
|
||||||
|
if access_mode not in ["public", "private", "private_all"]:
|
||||||
|
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
|
||||||
|
|
||||||
|
data = {"appId": app_id, "accessMode": access_mode}
|
||||||
|
|
||||||
|
response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data)
|
||||||
|
|
||||||
|
return response.get("result", False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def cleanup_webapp(cls, app_id: str):
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("app_id must be provided.")
|
||||||
|
|
||||||
|
body = {"appId": app_id}
|
||||||
|
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
|
||||||
|
|||||||
@ -0,0 +1,137 @@
|
|||||||
|
import random
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.web.error import WebAppAuthAccessDeniedError
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.helper import TokenManager
|
||||||
|
from libs.passport import PassportService
|
||||||
|
from libs.password import compare_password
|
||||||
|
from models.account import Account, AccountStatus
|
||||||
|
from models.model import App, EndUser, Site
|
||||||
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||||
|
from services.feature_service import FeatureService
|
||||||
|
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppAuthService:
|
||||||
|
"""Service for web app authentication."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def authenticate(email: str, password: str) -> Account:
|
||||||
|
"""authenticate account with email and password"""
|
||||||
|
|
||||||
|
account = Account.query.filter_by(email=email).first()
|
||||||
|
if not account:
|
||||||
|
raise AccountNotFoundError()
|
||||||
|
|
||||||
|
if account.status == AccountStatus.BANNED.value:
|
||||||
|
raise AccountLoginError("Account is banned.")
|
||||||
|
|
||||||
|
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||||
|
raise AccountPasswordError("Invalid email or password.")
|
||||||
|
|
||||||
|
return cast(Account, account)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
|
||||||
|
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||||
|
if not site:
|
||||||
|
raise NotFound("Site not found.")
|
||||||
|
|
||||||
|
access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
|
||||||
|
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user_through_email(cls, email: str):
|
||||||
|
account = db.session.query(Account).filter(Account.email == email).first()
|
||||||
|
if not account:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if account.status == AccountStatus.BANNED.value:
|
||||||
|
raise Unauthorized("Account is banned.")
|
||||||
|
|
||||||
|
return account
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def send_email_code_login_email(
|
||||||
|
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
|
||||||
|
):
|
||||||
|
email = account.email if account else email
|
||||||
|
if email is None:
|
||||||
|
raise ValueError("Email must be provided.")
|
||||||
|
|
||||||
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||||
|
token = TokenManager.generate_token(
|
||||||
|
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
|
||||||
|
)
|
||||||
|
send_email_code_login_mail_task.delay(
|
||||||
|
language=language,
|
||||||
|
to=account.email if account else email,
|
||||||
|
code=code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||||
|
return TokenManager.get_token_data(token, "webapp_email_code_login")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def revoke_email_code_login_token(cls, token: str):
|
||||||
|
TokenManager.revoke_token(token, "webapp_email_code_login")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_end_user(cls, app_code, email) -> EndUser:
|
||||||
|
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||||
|
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||||
|
end_user = EndUser(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
type="browser",
|
||||||
|
is_anonymous=False,
|
||||||
|
session_id=email,
|
||||||
|
name="enterpriseuser",
|
||||||
|
external_user_id="enterpriseuser",
|
||||||
|
)
|
||||||
|
db.session.add(end_user)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return end_user
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_user_accessibility(cls, account: Account, app_code: str):
|
||||||
|
"""Check if the user is allowed to access the app."""
|
||||||
|
system_features = FeatureService.get_system_features()
|
||||||
|
if system_features.webapp_auth.enabled:
|
||||||
|
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||||
|
|
||||||
|
if (
|
||||||
|
app_settings.access_mode != "public"
|
||||||
|
and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
|
||||||
|
):
|
||||||
|
raise WebAppAuthAccessDeniedError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
|
||||||
|
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24)
|
||||||
|
exp = int(exp_dt.timestamp())
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"iss": site.id,
|
||||||
|
"sub": "Web API Passport",
|
||||||
|
"app_id": site.app_id,
|
||||||
|
"app_code": site.code,
|
||||||
|
"user_id": account.id,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"token_source": "webapp",
|
||||||
|
"exp": exp,
|
||||||
|
}
|
||||||
|
|
||||||
|
token: str = PassportService().issue(payload)
|
||||||
|
return token
|
||||||
Loading…
Reference in New Issue