Feat/enterprise sso (#3602)
parent
d9f1a8ce9f
commit
4481906be2
@ -0,0 +1,59 @@
|
||||
from flask import current_app, redirect
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from services.enterprise.enterprise_sso_service import EnterpriseSSOService
|
||||
|
||||
|
||||
class EnterpriseSSOSamlLogin(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
return EnterpriseSSOService.get_sso_saml_login()
|
||||
|
||||
|
||||
class EnterpriseSSOSamlAcs(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('SAMLResponse', type=str, required=True, location='form')
|
||||
args = parser.parse_args()
|
||||
saml_response = args['SAMLResponse']
|
||||
|
||||
try:
|
||||
token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
|
||||
except Exception as e:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
|
||||
|
||||
|
||||
class EnterpriseSSOOidcLogin(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
return EnterpriseSSOService.get_sso_oidc_login()
|
||||
|
||||
|
||||
class EnterpriseSSOOidcCallback(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('state', type=str, required=True, location='args')
|
||||
parser.add_argument('code', type=str, required=True, location='args')
|
||||
parser.add_argument('oidc-state', type=str, required=True, location='cookies')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
token = EnterpriseSSOService.get_sso_oidc_callback(args)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
|
||||
except Exception as e:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
|
||||
|
||||
|
||||
api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
|
||||
api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
|
||||
api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
|
||||
api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')
|
||||
@ -0,0 +1,8 @@
|
||||
from flask import Blueprint
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from .workspace import workspace
|
||||
|
||||
@ -0,0 +1,37 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
from models.account import Account
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
class EnterpriseWorkspace(Resource):
|
||||
|
||||
@setup_required
|
||||
@inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('owner_email', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = Account.query.filter_by(email=args['owner_email']).first()
|
||||
if account is None:
|
||||
return {
|
||||
'message': 'owner account not found.'
|
||||
}, 404
|
||||
|
||||
tenant = TenantService.create_tenant(args['name'])
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
return {
|
||||
'message': 'enterprise workspace created.'
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')
|
||||
@ -0,0 +1,61 @@
|
||||
from base64 import b64encode
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
from flask import abort, current_app, request
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not current_app.config['INNER_API']:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||
if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def inner_api_user_auth(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not current_app.config['INNER_API']:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
authorization = request.headers.get('Authorization')
|
||||
if not authorization:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
parts = authorization.split(':')
|
||||
if len(parts) != 2:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
user_id, token = parts
|
||||
if ' ' in user_id:
|
||||
user_id = user_id.split(' ')[1]
|
||||
|
||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||
|
||||
data_to_sign = f'DIFY {user_id}'
|
||||
|
||||
signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
|
||||
signature = b64encode(signature.digest()).decode('utf-8')
|
||||
|
||||
if signature != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
@ -0,0 +1,20 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class EnterpriseRequest:
|
||||
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
|
||||
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
|
||||
|
||||
@classmethod
|
||||
def send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Enterprise-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
|
||||
return response.json()
|
||||
@ -0,0 +1,28 @@
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class EnterpriseFeatureModel(BaseModel):
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ''
|
||||
|
||||
|
||||
class EnterpriseFeatureService:
|
||||
|
||||
@classmethod
|
||||
def get_enterprise_features(cls) -> EnterpriseFeatureModel:
|
||||
features = EnterpriseFeatureModel()
|
||||
|
||||
if current_app.config['ENTERPRISE_ENABLED']:
|
||||
cls._fulfill_params_from_enterprise(features)
|
||||
|
||||
return features
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_enterprise(cls, features):
|
||||
enterprise_info = EnterpriseService.get_info()
|
||||
|
||||
features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
|
||||
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
|
||||
@ -0,0 +1,8 @@
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
return EnterpriseRequest.send_request('GET', '/info')
|
||||
@ -0,0 +1,60 @@
|
||||
import logging
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseSSOService:
|
||||
|
||||
@classmethod
|
||||
def get_sso_saml_login(cls) -> str:
|
||||
return EnterpriseRequest.send_request('GET', '/sso/saml/login')
|
||||
|
||||
@classmethod
|
||||
def post_sso_saml_acs(cls, saml_response: str) -> str:
|
||||
response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
|
||||
if 'email' not in response or response['email'] is None:
|
||||
logger.exception(response)
|
||||
raise Exception('Saml response is invalid')
|
||||
|
||||
return cls.login_with_email(response.get('email'))
|
||||
|
||||
@classmethod
|
||||
def get_sso_oidc_login(cls):
|
||||
return EnterpriseRequest.send_request('GET', '/sso/oidc/login')
|
||||
|
||||
@classmethod
|
||||
def get_sso_oidc_callback(cls, args: dict):
|
||||
state_from_query = args['state']
|
||||
code_from_query = args['code']
|
||||
state_from_cookies = args['oidc-state']
|
||||
|
||||
if state_from_cookies != state_from_query:
|
||||
raise Exception('invalid state or code')
|
||||
|
||||
response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
|
||||
if 'email' not in response or response['email'] is None:
|
||||
logger.exception(response)
|
||||
raise Exception('OIDC response is invalid')
|
||||
|
||||
return cls.login_with_email(response.get('email'))
|
||||
|
||||
@classmethod
|
||||
def login_with_email(cls, email: str) -> str:
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
if account is None:
|
||||
raise Exception('account not found, please contact system admin to invite you to join in a workspace')
|
||||
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Exception('account is banned, please contact system admin')
|
||||
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")
|
||||
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
return token
|
||||
@ -0,0 +1,87 @@
|
||||
'use client'
|
||||
import cn from 'classnames'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import type { FC } from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
||||
type EnterpriseSSOFormProps = {
|
||||
protocol: string
|
||||
}
|
||||
|
||||
const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({
|
||||
protocol,
|
||||
}) => {
|
||||
const searchParams = useSearchParams()
|
||||
const consoleToken = searchParams.get('console_token')
|
||||
const message = searchParams.get('message')
|
||||
|
||||
const router = useRouter()
|
||||
const { t } = useTranslation()
|
||||
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
if (consoleToken) {
|
||||
localStorage.setItem('console_token', consoleToken)
|
||||
router.replace('/apps')
|
||||
}
|
||||
|
||||
if (message) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message,
|
||||
})
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleSSOLogin = () => {
|
||||
setIsLoading(true)
|
||||
if (protocol === 'saml') {
|
||||
getSAMLSSOUrl().then((res) => {
|
||||
router.push(res.url)
|
||||
}).finally(() => {
|
||||
setIsLoading(false)
|
||||
})
|
||||
}
|
||||
else {
|
||||
getOIDCSSOUrl().then((res) => {
|
||||
document.cookie = `oidc-state=${res.state}`
|
||||
router.push(res.url)
|
||||
}).finally(() => {
|
||||
setIsLoading(false)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={
|
||||
cn(
|
||||
'flex flex-col items-center w-full grow items-center justify-center',
|
||||
'px-6',
|
||||
'md:px-[108px]',
|
||||
)
|
||||
}>
|
||||
<div className='flex flex-col md:w-[400px]'>
|
||||
<div className="w-full mx-auto">
|
||||
<h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2>
|
||||
</div>
|
||||
<div className="w-full mx-auto mt-10">
|
||||
<Button
|
||||
tabIndex={0}
|
||||
type='primary'
|
||||
onClick={() => { handleSSOLogin() }}
|
||||
disabled={isLoading}
|
||||
className="w-full !fone-medium !text-sm"
|
||||
>{t('login.sso')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default EnterpriseSSOForm
|
||||
@ -0,0 +1,14 @@
|
||||
import { get } from './base'
|
||||
import type { EnterpriseFeatures } from '@/types/enterprise'
|
||||
|
||||
export const getEnterpriseFeatures = () => {
|
||||
return get<EnterpriseFeatures>('/enterprise-features')
|
||||
}
|
||||
|
||||
export const getSAMLSSOUrl = () => {
|
||||
return get<{ url: string }>('/enterprise/sso/saml/login')
|
||||
}
|
||||
|
||||
export const getOIDCSSOUrl = () => {
|
||||
return get<{ url: string; state: string }>('/enterprise/sso/oidc/login')
|
||||
}
|
||||
@ -0,0 +1,9 @@
|
||||
export type EnterpriseFeatures = {
|
||||
sso_enforced_for_signin: boolean
|
||||
sso_enforced_for_signin_protocol: string
|
||||
}
|
||||
|
||||
export const defaultEnterpriseFeatures: EnterpriseFeatures = {
|
||||
sso_enforced_for_signin: false,
|
||||
sso_enforced_for_signin_protocol: '',
|
||||
}
|
||||
Loading…
Reference in New Issue