|
|
|
|
@ -6,6 +6,7 @@ import requests
|
|
|
|
|
from flask import current_app, redirect, request
|
|
|
|
|
from flask_restful import Resource
|
|
|
|
|
|
|
|
|
|
import services
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
from constants.languages import languages
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
@ -13,6 +14,7 @@ from libs.helper import get_remote_ip
|
|
|
|
|
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
|
|
|
|
from models.account import Account, AccountStatus
|
|
|
|
|
from services.account_service import AccountService, RegisterService, TenantService
|
|
|
|
|
from services.errors.account import AccountNotFound
|
|
|
|
|
|
|
|
|
|
from .. import api
|
|
|
|
|
|
|
|
|
|
@ -69,7 +71,10 @@ class OAuthCallback(Resource):
|
|
|
|
|
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
|
|
|
|
|
return {"error": "OAuth process failed"}, 400
|
|
|
|
|
|
|
|
|
|
account = _generate_account(provider, user_info)
|
|
|
|
|
try:
|
|
|
|
|
account = _generate_account(provider, user_info)
|
|
|
|
|
except services.errors.account.AccountNotFound as e:
|
|
|
|
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=AccountNotFound")
|
|
|
|
|
# Check account status
|
|
|
|
|
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
|
|
|
|
return {"error": "Account is banned or closed."}, 403
|
|
|
|
|
@ -99,8 +104,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|
|
|
|
# Get account by openid or email.
|
|
|
|
|
account = _get_account_by_openid_or_email(provider, user_info)
|
|
|
|
|
|
|
|
|
|
if not account:
|
|
|
|
|
# Create account
|
|
|
|
|
if not account and dify_config.ALLOW_REGISTER:
|
|
|
|
|
account_name = user_info.name if user_info.name else "Dify"
|
|
|
|
|
account = RegisterService.register(
|
|
|
|
|
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
|
|
|
|
@ -114,6 +118,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|
|
|
|
interface_language = languages[0]
|
|
|
|
|
account.interface_language = interface_language
|
|
|
|
|
db.session.commit()
|
|
|
|
|
else:
|
|
|
|
|
raise AccountNotFound()
|
|
|
|
|
|
|
|
|
|
# Link account
|
|
|
|
|
AccountService.link_account_integrate(provider, user_info.id, account)
|
|
|
|
|
|