feat: add oauth account not found

pull/8487/head
Joe 2 years ago
parent eadf75ad24
commit 955e2871f4

@ -6,6 +6,7 @@ import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource
import services
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from extensions.ext_database import db from extensions.ext_database import db
@ -13,6 +14,7 @@ from libs.helper import get_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFound
from .. import api from .. import api
@ -69,7 +71,10 @@ class OAuthCallback(Resource):
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
account = _generate_account(provider, user_info) try:
account = _generate_account(provider, user_info)
except services.errors.account.AccountNotFound as e:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=AccountNotFound")
# Check account status # Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
return {"error": "Account is banned or closed."}, 403 return {"error": "Account is banned or closed."}, 403
@ -99,8 +104,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Get account by openid or email. # Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info) account = _get_account_by_openid_or_email(provider, user_info)
if not account: if not account and dify_config.ALLOW_REGISTER:
# Create account
account_name = user_info.name if user_info.name else "Dify" account_name = user_info.name if user_info.name else "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
@ -114,6 +118,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
interface_language = languages[0] interface_language = languages[0]
account.interface_language = interface_language account.interface_language = interface_language
db.session.commit() db.session.commit()
else:
raise AccountNotFound()
# Link account # Link account
AccountService.link_account_integrate(provider, user_info.id, account) AccountService.link_account_integrate(provider, user_info.id, account)

@ -23,6 +23,7 @@ from models.model import DifySetup
from services.errors.account import ( from services.errors.account import (
AccountAlreadyInTenantError, AccountAlreadyInTenantError,
AccountLoginError, AccountLoginError,
AccountNotFound,
AccountNotLinkTenantError, AccountNotLinkTenantError,
AccountRegisterError, AccountRegisterError,
CannotOperateSelfError, CannotOperateSelfError,
@ -92,7 +93,7 @@ class AccountService:
account = Account.query.filter_by(email=email).first() account = Account.query.filter_by(email=email).first()
if not account: if not account:
raise AccountLoginError("Invalid email or password.") raise AccountNotFound()
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise AccountLoginError("Account is banned or closed.") raise AccountLoginError("Account is banned or closed.")
@ -330,6 +331,9 @@ class TenantService:
@staticmethod @staticmethod
def create_owner_tenant_if_not_exist(account: Account): def create_owner_tenant_if_not_exist(account: Account):
"""Create owner tenant if not exist""" """Create owner tenant if not exist"""
if not dify_config.ALLOW_CREATE_WORKSPACE:
raise Unauthorized("Create workspace is not allowed.")
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
) )

Loading…
Cancel
Save