use config to setup default id and debug email code

pull/21891/head
ytqh 1 year ago
parent 0f7291316e
commit c06c5d2918

@ -12,6 +12,7 @@ from .middleware import MiddlewareConfig
from .packaging import PackagingInfo from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource from .remote_settings_sources.apollo import ApolloSettingsSource
from .school import SchoolConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,6 +65,8 @@ class DifyConfig(
# Enterprise feature configs # Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.** # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig, EnterpriseFeatureConfig,
# School Configs
SchoolConfig,
): ):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
# read from dotenv format config file # read from dotenv format config file

@ -17,6 +17,11 @@ class DeploymentConfig(BaseSettings):
default=False, default=False,
) )
DEBUG_EMAIL_CODE_FOR_LOGIN: str = Field(
description="Default email code for login",
default="111111",
)
EDITION: str = Field( EDITION: str = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED", default="SELF_HOSTED",

@ -0,0 +1,18 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class SchoolConfig(BaseSettings):
"""
Configuration for school-level features.
"""
DEFAULT_APP_ID: str = Field(
description="Default app id for school-level features.",
default="b278ba96-fa8e-48a8-b3e9-debe34468be0",
)
DEFAULT_TENANT_ID: str = Field(
description="Default tenant id for school-level features.",
default="5cd3029e-7f92-428a-a5c8-14a790c70233",
)

@ -1,19 +1,14 @@
from typing import cast from typing import cast
import flask_login # type: ignore import flask_login # type: ignore
from configs.deploy import DeploymentConfig from configs import dify_config
from constants.languages import languages from constants.languages import languages
from controllers.service_api_with_auth import api from controllers.service_api_with_auth import api
from controllers.service_api_with_auth.auth.error import ( from controllers.service_api_with_auth.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
)
from controllers.service_api_with_auth.error import ( from controllers.service_api_with_auth.error import (
AccountInFreezeError, AccountInFreezeError,
AccountNotFound, AccountNotFound,
EmailSendIpLimitError, EmailSendIpLimitError,
NotAllowedCreateWorkspace,
TenantNotFoundError, TenantNotFoundError,
) )
from flask import request from flask import request
@ -115,15 +110,11 @@ class EmailCodeLoginSendEmailApi(Resource):
if account is None: if account is None:
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email( token = AccountService.send_email_code_login_email(email=args["email"], language=language)
email=args["email"], language=language
)
else: else:
raise AccountNotFound() raise AccountNotFound()
else: else:
token = AccountService.send_email_code_login_email( token = AccountService.send_email_code_login_email(account=account, language=language)
account=account, language=language
)
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -172,21 +163,13 @@ class EmailCodeLoginApi(Resource):
description: Invalid token, email or code description: Invalid token, email or code
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
# TODO: ytqh add a new field for different tenant (default: Saier)
parser.add_argument(
"tenant_id",
type=str,
required=False,
default="5cd3029e-7f92-428a-a5c8-14a790c70233",
location="json",
) # TODO: ytqh move this to the config
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json") parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json") parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
tenant_id = args["tenant_id"] tenant_id = dify_config.DEFAULT_TENANT_ID
token_data = AccountService.get_email_code_login_data(args["token"]) token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None: if token_data is None:
@ -224,9 +207,7 @@ class EmailCodeLoginApi(Resource):
if connected_tenant is None or tenant not in connected_tenant: if connected_tenant is None or tenant not in connected_tenant:
TenantService.create_tenant_member(tenant, account, role="end_user") TenantService.create_tenant_member(tenant, account, role="end_user")
token_pair = AccountService.login( token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
account, ip_address=extract_remote_ip(request)
)
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}

@ -4,6 +4,7 @@ from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
from configs import dify_config
from extensions.ext_database import db from extensions.ext_database import db
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in # type: ignore from flask_login import user_logged_in # type: ignore
@ -58,8 +59,10 @@ def validate_app_token(view: Optional[Callable] = None):
except Exception as e: except Exception as e:
raise Unauthorized(f"Failed to extract user_id from token: {str(e)}") raise Unauthorized(f"Failed to extract user_id from token: {str(e)}")
# Get app model using hardcoded ID app_id = request.headers.get("X-App-Id")
app_id = "b278ba96-fa8e-48a8-b3e9-debe34468be0" # TODO: ytqh Replace with your actual hardcoded app ID if not app_id:
app_id = dify_config.DEFAULT_APP_ID
app_model = db.session.query(App).filter(App.id == app_id).first() app_model = db.session.query(App).filter(App.id == app_id).first()
if not app_model: if not app_model:

@ -69,9 +69,7 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService: class AccountService:
reset_password_rate_limiter = RateLimiter( reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1
)
email_code_login_rate_limiter = RateLimiter( email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
) )
@ -117,16 +115,12 @@ class AccountService:
if account.status == AccountStatus.BANNED.value: if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.") raise Unauthorized("Account is banned.")
current_tenant = TenantAccountJoin.query.filter_by( current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
account_id=account.id, current=True
).first()
if current_tenant: if current_tenant:
account.current_tenant_id = current_tenant.tenant_id account.current_tenant_id = current_tenant.tenant_id
else: else:
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id) TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if not available_ta: if not available_ta:
return None return None
@ -135,9 +129,7 @@ class AccountService:
available_ta.current = True available_ta.current = True
db.session.commit() db.session.commit()
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta( if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
minutes=10
):
account.last_active_at = datetime.now(UTC).replace(tzinfo=None) account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
@ -145,9 +137,7 @@ class AccountService:
@staticmethod @staticmethod
def get_account_jwt_token(account: Account) -> str: def get_account_jwt_token(account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta( exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES
)
exp = int(exp_dt.timestamp()) exp = int(exp_dt.timestamp())
payload = { payload = {
"user_id": account.id, "user_id": account.id,
@ -160,9 +150,7 @@ class AccountService:
return token return token
@staticmethod @staticmethod
def authenticate( def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
email: str, password: str, invite_token: Optional[str] = None
) -> Account:
"""authenticate account with email and password""" """authenticate account with email and password"""
account = Account.query.filter_by(email=email).first() account = Account.query.filter_by(email=email).first()
@ -181,9 +169,7 @@ class AccountService:
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
if account.password is None or not compare_password( if account.password is None or not compare_password(password, account.password, account.password_salt):
password, account.password, account.password_salt
):
raise AccountPasswordError("Invalid email or password.") raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
@ -197,9 +183,7 @@ class AccountService:
@staticmethod @staticmethod
def update_account_password(account, password, new_password): def update_account_password(account, password, new_password):
"""update account password""" """update account password"""
if account.password and not compare_password( if account.password and not compare_password(password, account.password, account.password_salt):
password, account.password, account.password_salt
):
raise CurrentPasswordIncorrectError("Current password is incorrect.") raise CurrentPasswordIncorrectError("Current password is incorrect.")
# may be raised # may be raised
@ -316,9 +300,7 @@ class AccountService:
def send_account_deletion_verification_email(cls, account: Account, code: str): def send_account_deletion_verification_email(cls, account: Account, code: str):
email = account.email email = account.email
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email): if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import ( from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
EmailCodeAccountDeletionRateLimitExceededError,
)
raise EmailCodeAccountDeletionRateLimitExceededError() raise EmailCodeAccountDeletionRateLimitExceededError()
@ -347,11 +329,9 @@ class AccountService:
"""Link account integrate""" """Link account integrate"""
try: try:
# Query whether there is an existing binding record for the same provider # Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = ( account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
AccountIntegrate.query.filter_by(
account_id=account.id, provider=provider account_id=account.id, provider=provider
).first() ).first()
)
if account_integrate: if account_integrate:
# If it exists, update the record # If it exists, update the record
@ -371,9 +351,7 @@ class AccountService:
db.session.commit() db.session.commit()
logging.info(f"Account {account.id} linked {provider} account {open_id}.") logging.info(f"Account {account.id} linked {provider} account {open_id}.")
except Exception as e: except Exception as e:
logging.exception( logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
f"Failed to link {provider} account {open_id} to Account {account.id}"
)
raise LinkAccountIntegrateError("Failed to link account.") from e raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod @staticmethod
@ -420,20 +398,14 @@ class AccountService:
@staticmethod @staticmethod
def logout(*, account: Account) -> None: def logout(*, account: Account) -> None:
refresh_token = redis_client.get( refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
AccountService._get_account_refresh_token_key(account.id)
)
if refresh_token: if refresh_token:
AccountService._delete_refresh_token( AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
refresh_token.decode("utf-8"), account.id
)
@staticmethod @staticmethod
def refresh_token(refresh_token: str) -> TokenPair: def refresh_token(refresh_token: str) -> TokenPair:
# Verify the refresh token # Verify the refresh token
account_id = redis_client.get( account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
AccountService._get_refresh_token_key(refresh_token)
)
if not account_id: if not account_id:
raise ValueError("Invalid refresh token") raise ValueError("Invalid refresh token")
@ -466,9 +438,7 @@ class AccountService:
raise ValueError("Email must be provided.") raise ValueError("Email must be provided.")
if cls.reset_password_rate_limiter.is_rate_limited(account_email): if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import ( from controllers.console.auth.error import PasswordResetRateLimitExceededError
PasswordResetRateLimitExceededError,
)
raise PasswordResetRateLimitExceededError() raise PasswordResetRateLimitExceededError()
@ -505,19 +475,13 @@ class AccountService:
email = account.email if account else email email = account.email if account else email
if email is None: if email is None:
raise ValueError("Email must be provided.") raise ValueError("Email must be provided.")
if ( if cls.email_code_login_rate_limiter.is_rate_limited(email) and not DeploymentConfig().DEBUG:
cls.email_code_login_rate_limiter.is_rate_limited(email) from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
and not DeploymentConfig().DEBUG
):
from controllers.console.auth.error import (
EmailCodeLoginRateLimitExceededError,
)
raise EmailCodeLoginRateLimitExceededError() raise EmailCodeLoginRateLimitExceededError()
# if debug mode, force set code to 111111
if DeploymentConfig().DEBUG: if DeploymentConfig().DEBUG:
code = "111111" # TODO: ytqh move this to config code = dify_config.DEBUG_EMAIL_CODE_FOR_LOGIN
else: else:
code = "".join([str(random.randint(0, 9)) for _ in range(6)]) code = "".join([str(random.randint(0, 9)) for _ in range(6)])
@ -615,9 +579,7 @@ class AccountService:
redis_client.setex(freeze_key, 60 * 60, 1) redis_client.setex(freeze_key, 60 * 60, 1)
return True return True
else: else:
redis_client.setex( redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
hour_limit_key, 60 * 10, hour_limit_count + 1
) # first time limit 10 minutes
# add hour limit count # add hour limit count
redis_client.incr(hour_limit_key) redis_client.incr(hour_limit_key)
@ -671,53 +633,38 @@ class TenantService:
): ):
"""Check if user have a workspace or not""" """Check if user have a workspace or not"""
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id) TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if available_ta: if available_ta:
return return
"""Create owner tenant if not exist""" """Create owner tenant if not exist"""
if ( if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
not FeatureService.get_system_features().is_allow_create_workspace
and not is_setup
):
raise WorkSpaceNotAllowedCreateError() raise WorkSpaceNotAllowedCreateError()
if name: if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup) tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else: else:
tenant = TenantService.create_tenant( tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
name=f"{account.name}'s Workspace", is_setup=is_setup
)
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
db.session.commit() db.session.commit()
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
@staticmethod @staticmethod
def create_tenant_member( def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
tenant: Tenant, account: Account, role: str = "normal"
) -> TenantAccountJoin:
"""Create tenant member""" """Create tenant member"""
if role == TenantAccountJoinRole.OWNER.value: if role == TenantAccountJoinRole.OWNER.value:
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
logging.error(f"Tenant {tenant.id} has already an owner.") logging.error(f"Tenant {tenant.id} has already an owner.")
raise Exception("Tenant already has an owner.") raise Exception("Tenant already has an owner.")
ta = ( ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
db.session.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
if ta: if ta:
ta.role = role ta.role = role
else: else:
ta = TenantAccountJoin( ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
tenant_id=tenant.id, account_id=account.id, role=role
)
db.session.add(ta) db.session.add(ta)
db.session.commit() db.session.commit()
@ -743,9 +690,7 @@ class TenantService:
if not tenant: if not tenant:
raise TenantNotFoundError("Tenant not found.") raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by( ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
tenant_id=tenant.id, account_id=account.id
).first()
if ta: if ta:
tenant.role = ta.role tenant.role = ta.role
else: else:
@ -772,9 +717,7 @@ class TenantService:
) )
if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError( raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
"Tenant not found or account is not a member of the tenant."
)
else: else:
TenantAccountJoin.query.filter( TenantAccountJoin.query.filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.account_id == account.id,
@ -841,9 +784,7 @@ class TenantService:
) )
@staticmethod @staticmethod
def get_user_role( def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
account: Account, tenant: Tenant
) -> Optional[TenantAccountJoinRole]:
"""Get the role of the current account for a given tenant""" """Get the role of the current account for a given tenant"""
join = ( join = (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
@ -861,9 +802,7 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar()) return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod @staticmethod
def check_member_permission( def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
tenant: Tenant, operator: Account, member: Account | None, action: str
) -> None:
"""Check member permission""" """Check member permission"""
perms = { perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@ -877,26 +816,18 @@ class TenantService:
if operator.id == member.id: if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by( ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
tenant_id=tenant.id, account_id=operator.id
).first()
if not ta_operator or ta_operator.role not in perms[action]: if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.") raise NoPermissionError(f"No permission to {action} member.")
@staticmethod @staticmethod
def remove_member_from_tenant( def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
tenant: Tenant, account: Account, operator: Account
) -> None:
"""Remove member from tenant""" """Remove member from tenant"""
if operator.id == account.id and TenantService.check_member_permission( if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
tenant, operator, account, "remove"
):
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
ta = TenantAccountJoin.query.filter_by( ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
tenant_id=tenant.id, account_id=account.id
).first()
if not ta: if not ta:
raise MemberNotInTenantError("Member not in tenant.") raise MemberNotInTenantError("Member not in tenant.")
@ -904,26 +835,18 @@ class TenantService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def update_member_role( def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
tenant: Tenant, member: Account, new_role: str, operator: Account
) -> None:
"""Update member role""" """Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update") TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by( target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
tenant_id=tenant.id, account_id=member.id
).first()
if target_member_join.role == new_role: if target_member_join.role == new_role:
raise RoleAlreadyAssignedError( raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
"The provided role is already assigned to the member."
)
if new_role == "owner": if new_role == "owner":
# Find the current owner and change their role to 'admin' # Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by( current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
tenant_id=tenant.id, role="owner"
).first()
current_owner_join.role = "admin" current_owner_join.role = "admin"
# Update the role of the target member # Update the role of the target member
@ -933,9 +856,7 @@ class TenantService:
@staticmethod @staticmethod
def dissolve_tenant(tenant: Tenant, operator: Account) -> None: def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
"""Dissolve tenant""" """Dissolve tenant"""
if not TenantService.check_member_permission( if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
tenant, operator, operator, "remove"
):
raise NoPermissionError("No permission to dissolve tenant.") raise NoPermissionError("No permission to dissolve tenant.")
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db.session.delete(tenant) db.session.delete(tenant)
@ -976,9 +897,7 @@ class RegisterService:
account.last_login_ip = ip_address account.last_login_ip = ip_address
account.initialized_at = datetime.now(UTC).replace(tzinfo=None) account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
TenantService.create_owner_tenant_if_not_exist( TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
account=account, is_setup=True
)
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
db.session.add(dify_setup) db.session.add(dify_setup)
@ -1022,10 +941,7 @@ class RegisterService:
if open_id is not None and provider is not None: if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account) AccountService.link_account_integrate(provider, open_id, account)
if ( if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
FeatureService.get_system_features().is_allow_create_workspace
and create_workspace_required
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
@ -1074,9 +990,7 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id) TenantService.switch_tenant(account, tenant.id)
else: else:
TenantService.check_member_permission(tenant, inviter, account, "add") TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by( ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
tenant_id=tenant.id, account_id=account.id
).first()
if not ta: if not ta:
TenantService.create_tenant_member(tenant, account, role) TenantService.create_tenant_member(tenant, account, role)
@ -1123,9 +1037,7 @@ class RegisterService:
def revoke_token(cls, workspace_id: str, email: str, token: str): def revoke_token(cls, workspace_id: str, email: str, token: str):
if workspace_id and email: if workspace_id and email:
email_hash = sha256(email.encode()).hexdigest() email_hash = sha256(email.encode()).hexdigest()
cache_key = "member_invite_token:{}, {}:{}".format( cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
workspace_id, email_hash, token
)
redis_client.delete(cache_key) redis_client.delete(cache_key)
else: else:
redis_client.delete(cls._get_invitation_token_key(token)) redis_client.delete(cls._get_invitation_token_key(token))
@ -1140,9 +1052,7 @@ class RegisterService:
tenant = ( tenant = (
db.session.query(Tenant) db.session.query(Tenant)
.filter( .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal"
)
.first() .first()
) )

Loading…
Cancel
Save