修订增加基座认证功能

pull/22807/head
Lynx 1 year ago
parent 03ac2d0f17
commit f6cb2a828c

@ -1,4 +1,6 @@
import logging
import time
import hashlib
from datetime import UTC, datetime
from typing import Optional
@ -14,7 +16,7 @@ from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from libs.oauth import DigitalBaseOAuth, GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
@ -24,6 +26,11 @@ from services.feature_service import FeatureService
from .. import api
BASE_OAUTH_CLIENT_ID="4e2c105294fe46a1862a273ea54f469c"
BASE_OAUTH_CLIENT_SECRET="02f551ecffb244c69f10eb792f37c3c71cbbcc97c43c5a83240416a6f7a0cec1c4"
BASE_OAUTH_URL="http://192.168.0.215:9002/schoolbase"
DIFY_WEB_URL="http://120.46.81.72:1349/apps"
DIFY_WEB_KB_URL="http://120.46.81.72:1349/datasets"
def get_oauth_providers():
with current_app.app_context():
@ -44,7 +51,22 @@ def get_oauth_providers():
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
)
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
# if not dify_config.DIGITAL_BASE_CLIENT_ID or not dify_config.DIGITAL_BASE_CLIENT_SECRET or not dify_config.DIGITAL_BASE_URL:
# digital_base_oauth = None
# else:
# digital_base_oauth = DigitalBaseOAuth(
# client_id=dify_config.DIGITAL_BASE_CLIENT_ID,
# client_secret=dify_config.DIGITAL_BASE_CLIENT_SECRET,
# redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/digitalbase",
# base_url=dify_config.DIGITAL_BASE_URL,
# )
digital_base_oauth = DigitalBaseOAuth(
client_id=BASE_OAUTH_CLIENT_ID,
client_secret=BASE_OAUTH_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/digitalbase",
base_url=BASE_OAUTH_URL,
)
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth, "digitalbase": digital_base_oauth}
return OAUTH_PROVIDERS
@ -67,6 +89,7 @@ class OAuthCallback(Resource):
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider:
logging.error(f"无效的认证提供方: {provider}")
return {"error": "Invalid provider"}, 400
code = request.args.get("code")
@ -80,7 +103,7 @@ class OAuthCallback(Resource):
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e:
error_text = e.response.text if e.response else str(e)
logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}")
logging.exception(f"OAuth认证过程中发生错误认证提供方: {provider},错误信息: {error_text}")
return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token):
@ -128,8 +151,17 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
kb = request.args.get("kb")
if kb:
return redirect(
# f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
f"{DIFY_WEB_KB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
# f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
f"{DIFY_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
@ -150,7 +182,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace:
# if not FeatureService.get_system_features().is_allow_create_workspace:
if tenant:
raise WorkSpaceNotAllowedCreateError()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
@ -159,8 +192,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(tenant)
if not account:
if not FeatureService.get_system_features().is_allow_register:
raise AccountNotFoundError()
# if not FeatureService.get_system_features().is_allow_register:
# raise AccountNotFoundError()
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider

@ -20,6 +20,7 @@ def download_plugin_pkg(plugin_unique_identifier: str):
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
return []
if len(plugin_ids) == 0:
return []

@ -1,9 +1,14 @@
import hashlib
import time
import urllib.parse
from dataclasses import dataclass
from typing import Optional
import requests
import logging
import json
# Initialize logger for this module
logger = logging.getLogger(__name__)
@dataclass
class OAuthUserInfo:
@ -131,3 +136,95 @@ class GoogleOAuth(OAuth):
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
class DigitalBaseOAuth(OAuth):
_AUTH_URL = "http://1.92.71.188/gzt/login" # 基座登录页地址
_TOKEN_URL = "/oauth2/getTokenByCode"
_USER_INFO_URL = "/oauth2/getUserInfoByToken"
_REFRESH_URL = "/oauth2/refreshSessionByToken"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str, base_url: str):
super().__init__(client_id, client_secret, redirect_uri)
self.base_url = base_url
self.app_key = client_id # 数字基座中 AppKey 等同于 client_id
self.app_secret = client_secret
def get_authorization_url(self, invite_token: Optional[str] = None):
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
}
if invite_token:
params["state"] = invite_token
# return f"{self.base_url}/oauth2/authorize?{urllib.parse.urlencode(params)}"
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def _generate_headers(self, body: Optional[dict] = None) -> dict:
timestamp = str(int(time.time() * 1000))
body_length = len(json.dumps(body).encode('utf-8')) if body else 0
content = f"{self.app_key}{timestamp}{body_length}"
# 第一步对AppKey + timestamp + bodyLength做sha256加密
sign = hashlib.sha256(content.encode('utf-8')).hexdigest()
# 第二步对sign + AppSecret做md5加密
open_sign = hashlib.md5(f"{sign}{self.app_secret}".encode('utf-8')).hexdigest()
headers = {
"openAppId": self.app_key,
"openTimestamp": timestamp,
"openSign": open_sign,
"Content-Type": "application/json"
}
# 调试日志 - 中文输出
logger.debug(f"数字基座认证请求头: {headers}")
logger.debug(f"签名生成步骤 - 原始内容: {content}, SHA256签名: {sign}, 最终签名: {open_sign}")
return headers
def get_access_token(self, code: str):
data = {"code": code}
headers = self._generate_headers(data)
response = requests.post(
f"{self.base_url}{self._TOKEN_URL}",
headers=headers,
json=data
)
response_json = response.json()
if response_json.get("retcode") != 0:
raise ValueError(f"Error in DigitalBase OAuth: {response_json.get('errmsg')}")
return response_json["data"]["accessToken"]
def get_raw_user_info(self, token: str):
data = {"accessToken": token}
headers = {
**self._generate_headers(data),
"Authorization": f"Bearer {token}"
}
response = requests.post(
f"{self.base_url}{self._USER_INFO_URL}",
headers=headers,
json={"accessToken": token}
)
response.raise_for_status()
response_json = response.json()
if response_json.get("retcode") != 0:
raise ValueError(f"Error in DigitalBase OAuth: {response_json.get('errmsg')}")
return response_json["data"]
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(
id=raw_info["eduID"],
name=raw_info["name"],
email=f"{raw_info['eduID']}@digitalbase.edu" # 基座可能不提供email使用eduID生成
)

@ -204,18 +204,18 @@ class AccountService:
is_setup: Optional[bool] = False,
) -> Account:
"""create account"""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
from controllers.console.error import AccountNotFound
# if not FeatureService.get_system_features().is_allow_register and not is_setup:
# from controllers.console.error import AccountNotFound
raise AccountNotFound()
# raise AccountNotFound()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
"30 days and is temporarily unavailable for new account registration"
)
)
# if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
# raise AccountRegisterError(
# description=(
# "This email account has been deleted within the past "
# "30 days and is temporarily unavailable for new account registration"
# )
# )
account = Account()
account.email = email
@ -407,8 +407,10 @@ class AccountService:
raise PasswordResetRateLimitExceededError()
code, token = cls.generate_reset_password_token(account_email, account)
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code}
)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
@ -417,22 +419,6 @@ class AccountService:
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def generate_reset_password_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data=additional_data
)
return code, token
@classmethod
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
@ -589,14 +575,14 @@ class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
and not is_setup
and not is_from_dashboard
):
from controllers.console.error import NotAllowedCreateWorkspace
raise NotAllowedCreateWorkspace()
# if (
# not FeatureService.get_system_features().is_allow_create_workspace
# and not is_setup
# and not is_from_dashboard
# ):
# from controllers.console.error import NotAllowedCreateWorkspace
# raise NotAllowedCreateWorkspace()
tenant = Tenant(name=name)
db.session.add(tenant)
@ -619,10 +605,17 @@ class TenantService:
return
"""Create owner tenant if not exist"""
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
raise WorkSpaceNotAllowedCreateError()
# if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
# raise WorkSpaceNotAllowedCreateError()
# TODO 需要补充逻辑根据用户的机构id寻找是否已经创建了tenant如果有则将用户加入该tanent如果没有先创建tanent再将用户加入tanent
# 当前将用户全部加进一个默认tanent
name = "教育大模型应用空间"
if name:
tenant = TenantService.get_tenant_by_name(name)
if tenant is None:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else:
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
@ -835,6 +828,11 @@ class TenantService:
db.session.delete(tenant)
db.session.commit()
@staticmethod
def get_tenant_by_name(name: str) -> Optional[Tenant]:
"""Get tenant by name"""
return db.session.query(Tenant).filter(Tenant.name == name).first()
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()

@ -47,7 +47,7 @@ class FeatureModel(BaseModel):
members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
knowledge_rate_limit: int = 10
knowledge_rate_limit: int = 100
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = "standard"
@ -66,18 +66,18 @@ class KnowledgeRateLimitModel(BaseModel):
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
sso_enforced_for_signin: bool = True
sso_enforced_for_signin_protocol: str = "oauth2"
sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = ""
sso_enforced_for_web_protocol: str = "oauth2"
enable_web_sso_switch_component: bool = False
enable_marketplace: bool = False
max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
enable_email_code_login: bool = False
enable_email_password_login: bool = True
enable_social_oauth_login: bool = False
enable_social_oauth_login: bool = True
is_allow_register: bool = False
is_allow_create_workspace: bool = False
is_allow_create_workspace: bool = True
is_email_setup: bool = False
license: LicenseModel = LicenseModel()
@ -100,7 +100,7 @@ class FeatureService:
if dify_config.BILLING_ENABLED and tenant_id:
knowledge_rate_limit.enabled = True
limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
knowledge_rate_limit.limit = limit_info.get("limit", 10)
knowledge_rate_limit.limit = limit_info.get("limit", 100)
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox")
return knowledge_rate_limit
@ -110,10 +110,10 @@ class FeatureService:
cls._fulfill_system_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True
# if dify_config.ENTERPRISE_ENABLED:
# system_features.enable_web_sso_switch_component = True
cls._fulfill_params_from_enterprise(system_features)
# cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
@ -124,9 +124,11 @@ class FeatureService:
def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel):
system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
# system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
system_features.enable_social_oauth_login = True
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
# system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_allow_create_workspace = True
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
@classmethod
@ -134,7 +136,7 @@ class FeatureService:
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED
features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
features.education.enabled = dify_config.EDUCATION_ENABLED
# features.education.enabled = dify_config.EDUCATION_ENABLED
@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
@ -181,35 +183,35 @@ class FeatureService:
def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info()
if "sso_enforced_for_signin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
# if "sso_enforced_for_signin" in enterprise_info:
# features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
if "sso_enforced_for_signin_protocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
# if "sso_enforced_for_signin_protocol" in enterprise_info:
# features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
if "sso_enforced_for_web" in enterprise_info:
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
# if "sso_enforced_for_web" in enterprise_info:
# features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
if "sso_enforced_for_web_protocol" in enterprise_info:
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
# if "sso_enforced_for_web_protocol" in enterprise_info:
# features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
if "enable_email_code_login" in enterprise_info:
features.enable_email_code_login = enterprise_info["enable_email_code_login"]
# if "enable_email_code_login" in enterprise_info:
# features.enable_email_code_login = enterprise_info["enable_email_code_login"]
if "enable_email_password_login" in enterprise_info:
features.enable_email_password_login = enterprise_info["enable_email_password_login"]
# if "enable_email_password_login" in enterprise_info:
# features.enable_email_password_login = enterprise_info["enable_email_password_login"]
if "is_allow_register" in enterprise_info:
features.is_allow_register = enterprise_info["is_allow_register"]
# if "is_allow_register" in enterprise_info:
# features.is_allow_register = enterprise_info["is_allow_register"]
if "is_allow_create_workspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"]
# if "is_allow_create_workspace" in enterprise_info:
# features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"]
if "license" in enterprise_info:
license_info = enterprise_info["license"]
# if "license" in enterprise_info:
# license_info = enterprise_info["license"]
if "status" in license_info:
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
# if "status" in license_info:
# features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
if "expired_at" in license_info:
features.license.expired_at = license_info["expired_at"]
# if "expired_at" in license_info:
# features.license.expired_at = license_info["expired_at"]

Loading…
Cancel
Save