修订增加基座认证功能

pull/22807/head
Lynx 1 year ago
parent 03ac2d0f17
commit f6cb2a828c

@ -1,4 +1,6 @@
import logging import logging
import time
import hashlib
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Optional from typing import Optional
@ -14,7 +16,7 @@ from constants.languages import languages
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import DigitalBaseOAuth, GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account from models import Account
from models.account import AccountStatus from models.account import AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
@ -24,6 +26,11 @@ from services.feature_service import FeatureService
from .. import api from .. import api
BASE_OAUTH_CLIENT_ID="4e2c105294fe46a1862a273ea54f469c"
BASE_OAUTH_CLIENT_SECRET="02f551ecffb244c69f10eb792f37c3c71cbbcc97c43c5a83240416a6f7a0cec1c4"
BASE_OAUTH_URL="http://192.168.0.215:9002/schoolbase"
DIFY_WEB_URL="http://120.46.81.72:1349/apps"
DIFY_WEB_KB_URL="http://120.46.81.72:1349/datasets"
def get_oauth_providers(): def get_oauth_providers():
with current_app.app_context(): with current_app.app_context():
@ -44,7 +51,22 @@ def get_oauth_providers():
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
) )
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} # if not dify_config.DIGITAL_BASE_CLIENT_ID or not dify_config.DIGITAL_BASE_CLIENT_SECRET or not dify_config.DIGITAL_BASE_URL:
# digital_base_oauth = None
# else:
# digital_base_oauth = DigitalBaseOAuth(
# client_id=dify_config.DIGITAL_BASE_CLIENT_ID,
# client_secret=dify_config.DIGITAL_BASE_CLIENT_SECRET,
# redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/digitalbase",
# base_url=dify_config.DIGITAL_BASE_URL,
# )
digital_base_oauth = DigitalBaseOAuth(
client_id=BASE_OAUTH_CLIENT_ID,
client_secret=BASE_OAUTH_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/digitalbase",
base_url=BASE_OAUTH_URL,
)
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth, "digitalbase": digital_base_oauth}
return OAUTH_PROVIDERS return OAUTH_PROVIDERS
@ -67,6 +89,7 @@ class OAuthCallback(Resource):
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider) oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider: if not oauth_provider:
logging.error(f"无效的认证提供方: {provider}")
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
code = request.args.get("code") code = request.args.get("code")
@ -80,7 +103,7 @@ class OAuthCallback(Resource):
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
error_text = e.response.text if e.response else str(e) error_text = e.response.text if e.response else str(e)
logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") logging.exception(f"OAuth认证过程中发生错误认证提供方: {provider},错误信息: {error_text}")
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token): if invite_token and RegisterService.is_valid_invite_token(invite_token):
@ -128,8 +151,17 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request), ip_address=extract_remote_ip(request),
) )
kb = request.args.get("kb")
if kb:
return redirect(
# f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
f"{DIFY_WEB_KB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
return redirect( return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" # f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
f"{DIFY_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
) )
@ -150,7 +182,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if account: if account:
tenant = TenantService.get_join_tenants(account) tenant = TenantService.get_join_tenants(account)
if not tenant: if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace: # if not FeatureService.get_system_features().is_allow_create_workspace:
if tenant:
raise WorkSpaceNotAllowedCreateError() raise WorkSpaceNotAllowedCreateError()
else: else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
@ -159,8 +192,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
if not account: if not account:
if not FeatureService.get_system_features().is_allow_register: # if not FeatureService.get_system_features().is_allow_register:
raise AccountNotFoundError() # raise AccountNotFoundError()
account_name = user_info.name or "Dify" account_name = user_info.name or "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider

@ -20,6 +20,7 @@ def download_plugin_pkg(plugin_unique_identifier: str):
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]: def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
return []
if len(plugin_ids) == 0: if len(plugin_ids) == 0:
return [] return []

@ -1,9 +1,14 @@
import hashlib
import time
import urllib.parse import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
import requests import requests
import logging
import json
# Initialize logger for this module
logger = logging.getLogger(__name__)
@dataclass @dataclass
class OAuthUserInfo: class OAuthUserInfo:
@ -131,3 +136,95 @@ class GoogleOAuth(OAuth):
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
class DigitalBaseOAuth(OAuth):
_AUTH_URL = "http://1.92.71.188/gzt/login" # 基座登录页地址
_TOKEN_URL = "/oauth2/getTokenByCode"
_USER_INFO_URL = "/oauth2/getUserInfoByToken"
_REFRESH_URL = "/oauth2/refreshSessionByToken"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str, base_url: str):
super().__init__(client_id, client_secret, redirect_uri)
self.base_url = base_url
self.app_key = client_id # 数字基座中 AppKey 等同于 client_id
self.app_secret = client_secret
def get_authorization_url(self, invite_token: Optional[str] = None):
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
}
if invite_token:
params["state"] = invite_token
# return f"{self.base_url}/oauth2/authorize?{urllib.parse.urlencode(params)}"
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def _generate_headers(self, body: Optional[dict] = None) -> dict:
timestamp = str(int(time.time() * 1000))
body_length = len(json.dumps(body).encode('utf-8')) if body else 0
content = f"{self.app_key}{timestamp}{body_length}"
# 第一步对AppKey + timestamp + bodyLength做sha256加密
sign = hashlib.sha256(content.encode('utf-8')).hexdigest()
# 第二步对sign + AppSecret做md5加密
open_sign = hashlib.md5(f"{sign}{self.app_secret}".encode('utf-8')).hexdigest()
headers = {
"openAppId": self.app_key,
"openTimestamp": timestamp,
"openSign": open_sign,
"Content-Type": "application/json"
}
# 调试日志 - 中文输出
logger.debug(f"数字基座认证请求头: {headers}")
logger.debug(f"签名生成步骤 - 原始内容: {content}, SHA256签名: {sign}, 最终签名: {open_sign}")
return headers
def get_access_token(self, code: str):
data = {"code": code}
headers = self._generate_headers(data)
response = requests.post(
f"{self.base_url}{self._TOKEN_URL}",
headers=headers,
json=data
)
response_json = response.json()
if response_json.get("retcode") != 0:
raise ValueError(f"Error in DigitalBase OAuth: {response_json.get('errmsg')}")
return response_json["data"]["accessToken"]
def get_raw_user_info(self, token: str):
data = {"accessToken": token}
headers = {
**self._generate_headers(data),
"Authorization": f"Bearer {token}"
}
response = requests.post(
f"{self.base_url}{self._USER_INFO_URL}",
headers=headers,
json={"accessToken": token}
)
response.raise_for_status()
response_json = response.json()
if response_json.get("retcode") != 0:
raise ValueError(f"Error in DigitalBase OAuth: {response_json.get('errmsg')}")
return response_json["data"]
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(
id=raw_info["eduID"],
name=raw_info["name"],
email=f"{raw_info['eduID']}@digitalbase.edu" # 基座可能不提供email使用eduID生成
)

@ -204,18 +204,18 @@ class AccountService:
is_setup: Optional[bool] = False, is_setup: Optional[bool] = False,
) -> Account: ) -> Account:
"""create account""" """create account"""
if not FeatureService.get_system_features().is_allow_register and not is_setup: # if not FeatureService.get_system_features().is_allow_register and not is_setup:
from controllers.console.error import AccountNotFound # from controllers.console.error import AccountNotFound
raise AccountNotFound() # raise AccountNotFound()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): # if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
raise AccountRegisterError( # raise AccountRegisterError(
description=( # description=(
"This email account has been deleted within the past " # "This email account has been deleted within the past "
"30 days and is temporarily unavailable for new account registration" # "30 days and is temporarily unavailable for new account registration"
) # )
) # )
account = Account() account = Account()
account.email = email account.email = email
@ -407,8 +407,10 @@ class AccountService:
raise PasswordResetRateLimitExceededError() raise PasswordResetRateLimitExceededError()
code, token = cls.generate_reset_password_token(account_email, account) code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code}
)
send_reset_password_mail_task.delay( send_reset_password_mail_task.delay(
language=language, language=language,
to=account_email, to=account_email,
@ -417,22 +419,6 @@ class AccountService:
cls.reset_password_rate_limiter.increment_rate_limit(account_email) cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token return token
@classmethod
def generate_reset_password_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data=additional_data
)
return code, token
@classmethod @classmethod
def revoke_reset_password_token(cls, token: str): def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password") TokenManager.revoke_token(token, "reset_password")
@ -589,14 +575,14 @@ class TenantService:
@staticmethod @staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
"""Create tenant""" """Create tenant"""
if ( # if (
not FeatureService.get_system_features().is_allow_create_workspace # not FeatureService.get_system_features().is_allow_create_workspace
and not is_setup # and not is_setup
and not is_from_dashboard # and not is_from_dashboard
): # ):
from controllers.console.error import NotAllowedCreateWorkspace # from controllers.console.error import NotAllowedCreateWorkspace
raise NotAllowedCreateWorkspace() # raise NotAllowedCreateWorkspace()
tenant = Tenant(name=name) tenant = Tenant(name=name)
db.session.add(tenant) db.session.add(tenant)
@ -619,10 +605,17 @@ class TenantService:
return return
"""Create owner tenant if not exist""" """Create owner tenant if not exist"""
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: # if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
raise WorkSpaceNotAllowedCreateError() # raise WorkSpaceNotAllowedCreateError()
# TODO 需要补充逻辑根据用户的机构id寻找是否已经创建了tenant如果有则将用户加入该tanent如果没有先创建tanent再将用户加入tanent
# 当前将用户全部加进一个默认tanent
name = "教育大模型应用空间"
if name: if name:
tenant = TenantService.get_tenant_by_name(name)
if tenant is None:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup) tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else: else:
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
@ -835,6 +828,11 @@ class TenantService:
db.session.delete(tenant) db.session.delete(tenant)
db.session.commit() db.session.commit()
@staticmethod
def get_tenant_by_name(name: str) -> Optional[Tenant]:
"""Get tenant by name"""
return db.session.query(Tenant).filter(Tenant.name == name).first()
@staticmethod @staticmethod
def get_custom_config(tenant_id: str) -> dict: def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()

@ -47,7 +47,7 @@ class FeatureModel(BaseModel):
members: LimitationModel = LimitationModel(size=0, limit=1) members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10) apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5) vector_space: LimitationModel = LimitationModel(size=0, limit=5)
knowledge_rate_limit: int = 10 knowledge_rate_limit: int = 100
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = "standard" docs_processing: str = "standard"
@ -66,18 +66,18 @@ class KnowledgeRateLimitModel(BaseModel):
class SystemFeatureModel(BaseModel): class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False sso_enforced_for_signin: bool = True
sso_enforced_for_signin_protocol: str = "" sso_enforced_for_signin_protocol: str = "oauth2"
sso_enforced_for_web: bool = False sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = "" sso_enforced_for_web_protocol: str = "oauth2"
enable_web_sso_switch_component: bool = False enable_web_sso_switch_component: bool = False
enable_marketplace: bool = False enable_marketplace: bool = False
max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
enable_email_code_login: bool = False enable_email_code_login: bool = False
enable_email_password_login: bool = True enable_email_password_login: bool = True
enable_social_oauth_login: bool = False enable_social_oauth_login: bool = True
is_allow_register: bool = False is_allow_register: bool = False
is_allow_create_workspace: bool = False is_allow_create_workspace: bool = True
is_email_setup: bool = False is_email_setup: bool = False
license: LicenseModel = LicenseModel() license: LicenseModel = LicenseModel()
@ -100,7 +100,7 @@ class FeatureService:
if dify_config.BILLING_ENABLED and tenant_id: if dify_config.BILLING_ENABLED and tenant_id:
knowledge_rate_limit.enabled = True knowledge_rate_limit.enabled = True
limit_info = BillingService.get_knowledge_rate_limit(tenant_id) limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
knowledge_rate_limit.limit = limit_info.get("limit", 10) knowledge_rate_limit.limit = limit_info.get("limit", 100)
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox") knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox")
return knowledge_rate_limit return knowledge_rate_limit
@ -110,10 +110,10 @@ class FeatureService:
cls._fulfill_system_params_from_env(system_features) cls._fulfill_system_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED: # if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True # system_features.enable_web_sso_switch_component = True
cls._fulfill_params_from_enterprise(system_features) # cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED: if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True system_features.enable_marketplace = True
@ -124,9 +124,11 @@ class FeatureService:
def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel): def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel):
system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN # system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
system_features.enable_social_oauth_login = True
system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE # system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_allow_create_workspace = True
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
@classmethod @classmethod
@ -134,7 +136,7 @@ class FeatureService:
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED
features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
features.education.enabled = dify_config.EDUCATION_ENABLED # features.education.enabled = dify_config.EDUCATION_ENABLED
@classmethod @classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
@ -181,35 +183,35 @@ class FeatureService:
def _fulfill_params_from_enterprise(cls, features): def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info() enterprise_info = EnterpriseService.get_info()
if "sso_enforced_for_signin" in enterprise_info: # if "sso_enforced_for_signin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] # features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
if "sso_enforced_for_signin_protocol" in enterprise_info: # if "sso_enforced_for_signin_protocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] # features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
if "sso_enforced_for_web" in enterprise_info: # if "sso_enforced_for_web" in enterprise_info:
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] # features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
if "sso_enforced_for_web_protocol" in enterprise_info: # if "sso_enforced_for_web_protocol" in enterprise_info:
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] # features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
if "enable_email_code_login" in enterprise_info: # if "enable_email_code_login" in enterprise_info:
features.enable_email_code_login = enterprise_info["enable_email_code_login"] # features.enable_email_code_login = enterprise_info["enable_email_code_login"]
if "enable_email_password_login" in enterprise_info: # if "enable_email_password_login" in enterprise_info:
features.enable_email_password_login = enterprise_info["enable_email_password_login"] # features.enable_email_password_login = enterprise_info["enable_email_password_login"]
if "is_allow_register" in enterprise_info: # if "is_allow_register" in enterprise_info:
features.is_allow_register = enterprise_info["is_allow_register"] # features.is_allow_register = enterprise_info["is_allow_register"]
if "is_allow_create_workspace" in enterprise_info: # if "is_allow_create_workspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] # features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"]
if "license" in enterprise_info: # if "license" in enterprise_info:
license_info = enterprise_info["license"] # license_info = enterprise_info["license"]
if "status" in license_info: # if "status" in license_info:
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) # features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
if "expired_at" in license_info: # if "expired_at" in license_info:
features.license.expired_at = license_info["expired_at"] # features.license.expired_at = license_info["expired_at"]

Loading…
Cancel
Save