diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 33bafbf463..89b0df4454 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,4 +1,6 @@ import logging +import time +import hashlib from datetime import UTC, datetime from typing import Optional @@ -14,7 +16,7 @@ from constants.languages import languages from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import extract_remote_ip -from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo +from libs.oauth import DigitalBaseOAuth, GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService @@ -24,6 +26,11 @@ from services.feature_service import FeatureService from .. import api +BASE_OAUTH_CLIENT_ID="4e2c105294fe46a1862a273ea54f469c" +BASE_OAUTH_CLIENT_SECRET="02f551ecffb244c69f10eb792f37c3c71cbbcc97c43c5a83240416a6f7a0cec1c4" +BASE_OAUTH_URL="http://192.168.0.215:9002/schoolbase" +DIFY_WEB_URL="http://120.46.81.72:1349/apps" +DIFY_WEB_KB_URL="http://120.46.81.72:1349/datasets" def get_oauth_providers(): with current_app.app_context(): @@ -44,7 +51,22 @@ def get_oauth_providers(): redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", ) - OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} + # if not dify_config.DIGITAL_BASE_CLIENT_ID or not dify_config.DIGITAL_BASE_CLIENT_SECRET or not dify_config.DIGITAL_BASE_URL: + # digital_base_oauth = None + # else: + # digital_base_oauth = DigitalBaseOAuth( + # client_id=dify_config.DIGITAL_BASE_CLIENT_ID, + # client_secret=dify_config.DIGITAL_BASE_CLIENT_SECRET, + # redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/digitalbase", + # base_url=dify_config.DIGITAL_BASE_URL, + # ) + digital_base_oauth = DigitalBaseOAuth( + client_id=BASE_OAUTH_CLIENT_ID, + client_secret=BASE_OAUTH_CLIENT_SECRET, + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/digitalbase", + base_url=BASE_OAUTH_URL, + ) + OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth, "digitalbase": digital_base_oauth} return OAUTH_PROVIDERS @@ -67,6 +89,7 @@ class OAuthCallback(Resource): with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: + logging.error(f"无效的认证提供方: {provider}") return {"error": "Invalid provider"}, 400 code = request.args.get("code") @@ -80,7 +103,7 @@ class OAuthCallback(Resource): user_info = oauth_provider.get_user_info(token) except requests.exceptions.RequestException as e: error_text = e.response.text if e.response else str(e) - logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") + logging.exception(f"OAuth认证过程中发生错误,认证提供方: {provider},错误信息: {error_text}") return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): @@ -127,9 +150,18 @@ class OAuthCallback(Resource): account=account, ip_address=extract_remote_ip(request), ) + + kb = request.args.get("kb") + if kb: + return redirect( + # f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" + f"{DIFY_WEB_KB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" + ) + return redirect( - f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" + # f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" + f"{DIFY_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" ) @@ -150,7 +182,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if account: tenant = TenantService.get_join_tenants(account) if not tenant: - if not FeatureService.get_system_features().is_allow_create_workspace: + # if not FeatureService.get_system_features().is_allow_create_workspace: + if tenant: raise WorkSpaceNotAllowedCreateError() else: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") @@ -159,8 +192,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): tenant_was_created.send(tenant) if not account: - if not FeatureService.get_system_features().is_allow_register: - raise AccountNotFoundError() + # if not FeatureService.get_system_features().is_allow_register: + # raise AccountNotFoundError() account_name = user_info.name or "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index f4129b88ed..90be5ef879 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -20,6 +20,7 @@ def download_plugin_pkg(plugin_unique_identifier: str): def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]: + return [] if len(plugin_ids) == 0: return [] diff --git a/api/libs/oauth.py b/api/libs/oauth.py index df75b55019..73f6b247ea 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,9 +1,14 @@ +import hashlib +import time import urllib.parse from dataclasses import dataclass from typing import Optional import requests - +import logging +import json +# Initialize logger for this module +logger = logging.getLogger(__name__) @dataclass class OAuthUserInfo: @@ -131,3 +136,95 @@ class GoogleOAuth(OAuth): def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) + + +class DigitalBaseOAuth(OAuth): + _AUTH_URL = "http://1.92.71.188/gzt/login" # 基座登录页地址 + _TOKEN_URL = "/oauth2/getTokenByCode" + _USER_INFO_URL = "/oauth2/getUserInfoByToken" + _REFRESH_URL = "/oauth2/refreshSessionByToken" + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str, base_url: str): + super().__init__(client_id, client_secret, redirect_uri) + self.base_url = base_url + self.app_key = client_id # 数字基座中 AppKey 等同于 client_id + self.app_secret = client_secret + + def get_authorization_url(self, invite_token: Optional[str] = None): + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "response_type": "code", + } + if invite_token: + params["state"] = invite_token + # return f"{self.base_url}/oauth2/authorize?{urllib.parse.urlencode(params)}" + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def _generate_headers(self, body: Optional[dict] = None) -> dict: + timestamp = str(int(time.time() * 1000)) + body_length = len(json.dumps(body).encode('utf-8')) if body else 0 + content = f"{self.app_key}{timestamp}{body_length}" + + # 第一步:对AppKey + timestamp + bodyLength做sha256加密 + sign = hashlib.sha256(content.encode('utf-8')).hexdigest() + + # 第二步:对sign + AppSecret做md5加密 + open_sign = hashlib.md5(f"{sign}{self.app_secret}".encode('utf-8')).hexdigest() + + headers = { + "openAppId": self.app_key, + "openTimestamp": timestamp, + "openSign": open_sign, + "Content-Type": "application/json" + } + + # 调试日志 - 中文输出 + logger.debug(f"数字基座认证请求头: {headers}") + logger.debug(f"签名生成步骤 - 原始内容: {content}, SHA256签名: {sign}, 最终签名: {open_sign}") + + return headers + + def get_access_token(self, code: str): + data = {"code": code} + headers = self._generate_headers(data) + + response = requests.post( + f"{self.base_url}{self._TOKEN_URL}", + headers=headers, + json=data + ) + response_json = response.json() + + + if response_json.get("retcode") != 0: + raise ValueError(f"Error in DigitalBase OAuth: {response_json.get('errmsg')}") + + return response_json["data"]["accessToken"] + + def get_raw_user_info(self, token: str): + data = {"accessToken": token} + headers = { + **self._generate_headers(data), + "Authorization": f"Bearer {token}" + } + + response = requests.post( + f"{self.base_url}{self._USER_INFO_URL}", + headers=headers, + json={"accessToken": token} + ) + response.raise_for_status() + response_json = response.json() + + if response_json.get("retcode") != 0: + raise ValueError(f"Error in DigitalBase OAuth: {response_json.get('errmsg')}") + + return response_json["data"] + + def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: + return OAuthUserInfo( + id=raw_info["eduID"], + name=raw_info["name"], + email=f"{raw_info['eduID']}@digitalbase.edu" # 基座可能不提供email,使用eduID生成 + ) diff --git a/api/services/account_service.py b/api/services/account_service.py index f930ef910b..3534471bb1 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -204,18 +204,18 @@ class AccountService: is_setup: Optional[bool] = False, ) -> Account: """create account""" - if not FeatureService.get_system_features().is_allow_register and not is_setup: - from controllers.console.error import AccountNotFound + # if not FeatureService.get_system_features().is_allow_register and not is_setup: + # from controllers.console.error import AccountNotFound - raise AccountNotFound() + # raise AccountNotFound() - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): - raise AccountRegisterError( - description=( - "This email account has been deleted within the past " - "30 days and is temporarily unavailable for new account registration" - ) - ) + # if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): + # raise AccountRegisterError( + # description=( + # "This email account has been deleted within the past " + # "30 days and is temporarily unavailable for new account registration" + # ) + # ) account = Account() account.email = email @@ -407,8 +407,10 @@ class AccountService: raise PasswordResetRateLimitExceededError() - code, token = cls.generate_reset_password_token(account_email, account) - + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data={"code": code} + ) send_reset_password_mail_task.delay( language=language, to=account_email, @@ -417,22 +419,6 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token - @classmethod - def generate_reset_password_token( - cls, - email: str, - account: Optional[Account] = None, - code: Optional[str] = None, - additional_data: dict[str, Any] = {}, - ): - if not code: - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) - additional_data["code"] = code - token = TokenManager.generate_token( - account=account, email=email, token_type="reset_password", additional_data=additional_data - ) - return code, token - @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") @@ -589,14 +575,14 @@ class TenantService: @staticmethod def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: """Create tenant""" - if ( - not FeatureService.get_system_features().is_allow_create_workspace - and not is_setup - and not is_from_dashboard - ): - from controllers.console.error import NotAllowedCreateWorkspace - - raise NotAllowedCreateWorkspace() + # if ( + # not FeatureService.get_system_features().is_allow_create_workspace + # and not is_setup + # and not is_from_dashboard + # ): + # from controllers.console.error import NotAllowedCreateWorkspace + + # raise NotAllowedCreateWorkspace() tenant = Tenant(name=name) db.session.add(tenant) @@ -619,11 +605,18 @@ class TenantService: return """Create owner tenant if not exist""" - if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: - raise WorkSpaceNotAllowedCreateError() + # if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: + # raise WorkSpaceNotAllowedCreateError() + + # TODO 需要补充逻辑,根据用户的机构id,寻找是否已经创建了tenant,如果有,则将用户加入该tanent,如果没有,先创建tanent,再将用户加入tanent + # 当前将用户全部加进一个默认tanent + + name = "教育大模型应用空间" if name: - tenant = TenantService.create_tenant(name=name, is_setup=is_setup) + tenant = TenantService.get_tenant_by_name(name) + if tenant is None: + tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) TenantService.create_tenant_member(tenant, account, role="owner") @@ -835,6 +828,11 @@ class TenantService: db.session.delete(tenant) db.session.commit() + @staticmethod + def get_tenant_by_name(name: str) -> Optional[Tenant]: + """Get tenant by name""" + return db.session.query(Tenant).filter(Tenant.name == name).first() + @staticmethod def get_custom_config(tenant_id: str) -> dict: tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() diff --git a/api/services/feature_service.py b/api/services/feature_service.py index c2226c319f..1452506954 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -47,7 +47,7 @@ class FeatureModel(BaseModel): members: LimitationModel = LimitationModel(size=0, limit=1) apps: LimitationModel = LimitationModel(size=0, limit=10) vector_space: LimitationModel = LimitationModel(size=0, limit=5) - knowledge_rate_limit: int = 10 + knowledge_rate_limit: int = 100 annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) docs_processing: str = "standard" @@ -66,18 +66,18 @@ class KnowledgeRateLimitModel(BaseModel): class SystemFeatureModel(BaseModel): - sso_enforced_for_signin: bool = False - sso_enforced_for_signin_protocol: str = "" + sso_enforced_for_signin: bool = True + sso_enforced_for_signin_protocol: str = "oauth2" sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = "" + sso_enforced_for_web_protocol: str = "oauth2" enable_web_sso_switch_component: bool = False enable_marketplace: bool = False max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE enable_email_code_login: bool = False enable_email_password_login: bool = True - enable_social_oauth_login: bool = False + enable_social_oauth_login: bool = True is_allow_register: bool = False - is_allow_create_workspace: bool = False + is_allow_create_workspace: bool = True is_email_setup: bool = False license: LicenseModel = LicenseModel() @@ -100,7 +100,7 @@ class FeatureService: if dify_config.BILLING_ENABLED and tenant_id: knowledge_rate_limit.enabled = True limit_info = BillingService.get_knowledge_rate_limit(tenant_id) - knowledge_rate_limit.limit = limit_info.get("limit", 10) + knowledge_rate_limit.limit = limit_info.get("limit", 100) knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox") return knowledge_rate_limit @@ -110,10 +110,10 @@ class FeatureService: cls._fulfill_system_params_from_env(system_features) - if dify_config.ENTERPRISE_ENABLED: - system_features.enable_web_sso_switch_component = True + # if dify_config.ENTERPRISE_ENABLED: + # system_features.enable_web_sso_switch_component = True - cls._fulfill_params_from_enterprise(system_features) + # cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True @@ -124,9 +124,11 @@ class FeatureService: def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel): system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN - system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + # system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + system_features.enable_social_oauth_login = True system_features.is_allow_register = dify_config.ALLOW_REGISTER - system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE + # system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE + system_features.is_allow_create_workspace = True system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" @classmethod @@ -134,7 +136,7 @@ class FeatureService: features.can_replace_logo = dify_config.CAN_REPLACE_LOGO features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED - features.education.enabled = dify_config.EDUCATION_ENABLED + # features.education.enabled = dify_config.EDUCATION_ENABLED @classmethod def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): @@ -181,35 +183,35 @@ class FeatureService: def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() - if "sso_enforced_for_signin" in enterprise_info: - features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + # if "sso_enforced_for_signin" in enterprise_info: + # features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] - if "sso_enforced_for_signin_protocol" in enterprise_info: - features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + # if "sso_enforced_for_signin_protocol" in enterprise_info: + # features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] - if "sso_enforced_for_web" in enterprise_info: - features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + # if "sso_enforced_for_web" in enterprise_info: + # features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] - if "sso_enforced_for_web_protocol" in enterprise_info: - features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] + # if "sso_enforced_for_web_protocol" in enterprise_info: + # features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] - if "enable_email_code_login" in enterprise_info: - features.enable_email_code_login = enterprise_info["enable_email_code_login"] + # if "enable_email_code_login" in enterprise_info: + # features.enable_email_code_login = enterprise_info["enable_email_code_login"] - if "enable_email_password_login" in enterprise_info: - features.enable_email_password_login = enterprise_info["enable_email_password_login"] + # if "enable_email_password_login" in enterprise_info: + # features.enable_email_password_login = enterprise_info["enable_email_password_login"] - if "is_allow_register" in enterprise_info: - features.is_allow_register = enterprise_info["is_allow_register"] + # if "is_allow_register" in enterprise_info: + # features.is_allow_register = enterprise_info["is_allow_register"] - if "is_allow_create_workspace" in enterprise_info: - features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] + # if "is_allow_create_workspace" in enterprise_info: + # features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] - if "license" in enterprise_info: - license_info = enterprise_info["license"] + # if "license" in enterprise_info: + # license_info = enterprise_info["license"] - if "status" in license_info: - features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + # if "status" in license_info: + # features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - if "expired_at" in license_info: - features.license.expired_at = license_info["expired_at"] + # if "expired_at" in license_info: + # features.license.expired_at = license_info["expired_at"]