format account service

pull/21891/head
ytqh 1 year ago
parent b8a243c5f0
commit 8a43186637

@ -8,10 +8,6 @@ from datetime import UTC, datetime, timedelta
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional, cast from typing import Any, Optional, cast
from pydantic import BaseModel
from sqlalchemy import func
from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
from constants.languages import language_timezone_mapping, languages from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
@ -32,6 +28,7 @@ from models.account import (
TenantStatus, TenantStatus,
) )
from models.model import DifySetup from models.model import DifySetup
from pydantic import BaseModel
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import ( from services.errors.account import (
AccountAlreadyInTenantError, AccountAlreadyInTenantError,
@ -51,11 +48,13 @@ from services.errors.account import (
) )
from services.errors.workspace import WorkSpaceNotAllowedCreateError from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService from services.feature_service import FeatureService
from sqlalchemy import func
from tasks.delete_account_task import delete_account_task from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code from tasks.mail_account_deletion_task import send_account_deletion_verification_code
from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task
from werkzeug.exceptions import Unauthorized
class TokenPair(BaseModel): class TokenPair(BaseModel):
@ -69,12 +68,16 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService: class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) reset_password_rate_limiter = RateLimiter(
prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1
)
email_code_login_rate_limiter = RateLimiter( email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
) )
email_code_account_deletion_rate_limiter = RateLimiter( email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 prefix="email_code_account_deletion_rate_limit",
max_attempts=1,
time_window=60 * 1,
) )
LOGIN_MAX_ERROR_LIMITS = 5 LOGIN_MAX_ERROR_LIMITS = 5
@ -88,9 +91,15 @@ class AccountService:
@staticmethod @staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None: def _store_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex( redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token AccountService._get_refresh_token_key(refresh_token),
REFRESH_TOKEN_EXPIRY,
account_id,
)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id),
REFRESH_TOKEN_EXPIRY,
refresh_token,
) )
@staticmethod @staticmethod
@ -107,12 +116,16 @@ class AccountService:
if account.status == AccountStatus.BANNED.value: if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.") raise Unauthorized("Account is banned.")
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() current_tenant = TenantAccountJoin.query.filter_by(
account_id=account.id, current=True
).first()
if current_tenant: if current_tenant:
account.current_tenant_id = current_tenant.tenant_id account.current_tenant_id = current_tenant.tenant_id
else: else:
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() TenantAccountJoin.query.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if not available_ta: if not available_ta:
return None return None
@ -121,7 +134,9 @@ class AccountService:
available_ta.current = True available_ta.current = True
db.session.commit() db.session.commit()
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(
minutes=10
):
account.last_active_at = datetime.now(UTC).replace(tzinfo=None) account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
@ -129,7 +144,9 @@ class AccountService:
@staticmethod @staticmethod
def get_account_jwt_token(account: Account) -> str: def get_account_jwt_token(account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) exp_dt = datetime.now(UTC) + timedelta(
minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES
)
exp = int(exp_dt.timestamp()) exp = int(exp_dt.timestamp())
payload = { payload = {
"user_id": account.id, "user_id": account.id,
@ -142,7 +159,9 @@ class AccountService:
return token return token
@staticmethod @staticmethod
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: def authenticate(
email: str, password: str, invite_token: Optional[str] = None
) -> Account:
"""authenticate account with email and password""" """authenticate account with email and password"""
account = Account.query.filter_by(email=email).first() account = Account.query.filter_by(email=email).first()
@ -161,7 +180,9 @@ class AccountService:
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
if account.password is None or not compare_password(password, account.password, account.password_salt): if account.password is None or not compare_password(
password, account.password, account.password_salt
):
raise AccountPasswordError("Invalid email or password.") raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
@ -175,7 +196,9 @@ class AccountService:
@staticmethod @staticmethod
def update_account_password(account, password, new_password): def update_account_password(account, password, new_password):
"""update account password""" """update account password"""
if account.password and not compare_password(password, account.password, account.password_salt): if account.password and not compare_password(
password, account.password, account.password_salt
):
raise CurrentPasswordIncorrectError("Current password is incorrect.") raise CurrentPasswordIncorrectError("Current password is incorrect.")
# may be raised # may be raised
@ -244,11 +267,18 @@ class AccountService:
@staticmethod @staticmethod
def create_account_in_tenant( def create_account_in_tenant(
tenant: Tenant, email: str, name: str, interface_language: str, password: Optional[str] = None tenant: Tenant,
email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
) -> Account: ) -> Account:
"""create account""" """create account"""
account = AccountService.create_account( account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password email=email,
name=name,
interface_language=interface_language,
password=password,
) )
TenantService.create_tenant_member(tenant, account, role="end_user") TenantService.create_tenant_member(tenant, account, role="end_user")
return account return account
@ -259,7 +289,10 @@ class AccountService:
) -> Account: ) -> Account:
"""create account""" """create account"""
account = AccountService.create_account( account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password email=email,
name=name,
interface_language=interface_language,
password=password,
) )
TenantService.create_owner_tenant_if_not_exist(account=account) TenantService.create_owner_tenant_if_not_exist(account=account)
@ -267,10 +300,14 @@ class AccountService:
return account return account
@staticmethod @staticmethod
def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]: def generate_account_deletion_verification_code(
account: Account,
) -> tuple[str, str]:
code = "".join([str(random.randint(0, 9)) for _ in range(6)]) code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token( token = TokenManager.generate_token(
account=account, token_type="account_deletion", additional_data={"code": code} account=account,
token_type="account_deletion",
additional_data={"code": code},
) )
return token, code return token, code
@ -278,7 +315,9 @@ class AccountService:
def send_account_deletion_verification_email(cls, account: Account, code: str): def send_account_deletion_verification_email(cls, account: Account, code: str):
email = account.email email = account.email
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email): if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError from controllers.console.auth.error import (
EmailCodeAccountDeletionRateLimitExceededError,
)
raise EmailCodeAccountDeletionRateLimitExceededError() raise EmailCodeAccountDeletionRateLimitExceededError()
@ -307,9 +346,11 @@ class AccountService:
"""Link account integrate""" """Link account integrate"""
try: try:
# Query whether there is an existing binding record for the same provider # Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( account_integrate: Optional[AccountIntegrate] = (
account_id=account.id, provider=provider AccountIntegrate.query.filter_by(
).first() account_id=account.id, provider=provider
).first()
)
if account_integrate: if account_integrate:
# If it exists, update the record # If it exists, update the record
@ -319,14 +360,19 @@ class AccountService:
else: else:
# If it does not exist, create a new record # If it does not exist, create a new record
account_integrate = AccountIntegrate( account_integrate = AccountIntegrate(
account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" account_id=account.id,
provider=provider,
open_id=open_id,
encrypted_token="",
) )
db.session.add(account_integrate) db.session.add(account_integrate)
db.session.commit() db.session.commit()
logging.info(f"Account {account.id} linked {provider} account {open_id}.") logging.info(f"Account {account.id} linked {provider} account {open_id}.")
except Exception as e: except Exception as e:
logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") logging.exception(
f"Failed to link {provider} account {open_id} to Account {account.id}"
)
raise LinkAccountIntegrateError("Failed to link account.") from e raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod @staticmethod
@ -373,14 +419,20 @@ class AccountService:
@staticmethod @staticmethod
def logout(*, account: Account) -> None: def logout(*, account: Account) -> None:
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) refresh_token = redis_client.get(
AccountService._get_account_refresh_token_key(account.id)
)
if refresh_token: if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) AccountService._delete_refresh_token(
refresh_token.decode("utf-8"), account.id
)
@staticmethod @staticmethod
def refresh_token(refresh_token: str) -> TokenPair: def refresh_token(refresh_token: str) -> TokenPair:
# Verify the refresh token # Verify the refresh token
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) account_id = redis_client.get(
AccountService._get_refresh_token_key(refresh_token)
)
if not account_id: if not account_id:
raise ValueError("Invalid refresh token") raise ValueError("Invalid refresh token")
@ -413,13 +465,18 @@ class AccountService:
raise ValueError("Email must be provided.") raise ValueError("Email must be provided.")
if cls.reset_password_rate_limiter.is_rate_limited(account_email): if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError from controllers.console.auth.error import (
PasswordResetRateLimitExceededError,
)
raise PasswordResetRateLimitExceededError() raise PasswordResetRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)]) code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token( token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code} account=account,
email=email,
token_type="reset_password",
additional_data={"code": code},
) )
send_reset_password_mail_task.delay( send_reset_password_mail_task.delay(
language=language, language=language,
@ -439,19 +496,27 @@ class AccountService:
@classmethod @classmethod
def send_email_code_login_email( def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" cls,
account: Optional[Account] = None,
email: Optional[str] = None,
language: Optional[str] = "en-US",
): ):
email = account.email if account else email email = account.email if account else email
if email is None: if email is None:
raise ValueError("Email must be provided.") raise ValueError("Email must be provided.")
if cls.email_code_login_rate_limiter.is_rate_limited(email): if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError from controllers.console.auth.error import (
EmailCodeLoginRateLimitExceededError,
)
raise EmailCodeLoginRateLimitExceededError() raise EmailCodeLoginRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)]) code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token( token = TokenManager.generate_token(
account=account, email=email, token_type="email_code_login", additional_data={"code": code} account=account,
email=email,
token_type="email_code_login",
additional_data={"code": code},
) )
send_email_code_login_mail_task.delay( send_email_code_login_mail_task.delay(
language=language, language=language,
@ -541,7 +606,9 @@ class AccountService:
redis_client.setex(freeze_key, 60 * 60, 1) redis_client.setex(freeze_key, 60 * 60, 1)
return True return True
else: else:
redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes redis_client.setex(
hour_limit_key, 60 * 10, hour_limit_count + 1
) # first time limit 10 minutes
# add hour limit count # add hour limit count
redis_client.incr(hour_limit_key) redis_client.incr(hour_limit_key)
@ -564,9 +631,13 @@ class TenantService:
@staticmethod @staticmethod
def get_tenant_by_id(tenant_id: str) -> Tenant: def get_tenant_by_id(tenant_id: str) -> Tenant:
return Tenant.query.filter_by(id=tenant_id).first() return Tenant.query.filter_by(id=tenant_id).first()
@staticmethod @staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: def create_tenant(
name: str,
is_setup: Optional[bool] = False,
is_from_dashboard: Optional[bool] = False,
) -> Tenant:
"""Create tenant""" """Create tenant"""
if ( if (
not FeatureService.get_system_features().is_allow_create_workspace not FeatureService.get_system_features().is_allow_create_workspace
@ -591,38 +662,53 @@ class TenantService:
): ):
"""Check if user have a workspace or not""" """Check if user have a workspace or not"""
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() TenantAccountJoin.query.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if available_ta: if available_ta:
return return
"""Create owner tenant if not exist""" """Create owner tenant if not exist"""
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: if (
not FeatureService.get_system_features().is_allow_create_workspace
and not is_setup
):
raise WorkSpaceNotAllowedCreateError() raise WorkSpaceNotAllowedCreateError()
if name: if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup) tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else: else:
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) tenant = TenantService.create_tenant(
name=f"{account.name}'s Workspace", is_setup=is_setup
)
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
db.session.commit() db.session.commit()
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
@staticmethod @staticmethod
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: def create_tenant_member(
tenant: Tenant, account: Account, role: str = "normal"
) -> TenantAccountJoin:
"""Create tenant member""" """Create tenant member"""
if role == TenantAccountJoinRole.OWNER.value: if role == TenantAccountJoinRole.OWNER.value:
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
logging.error(f"Tenant {tenant.id} has already an owner.") logging.error(f"Tenant {tenant.id} has already an owner.")
raise Exception("Tenant already has an owner.") raise Exception("Tenant already has an owner.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = (
db.session.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
if ta: if ta:
ta.role = role ta.role = role
else: else:
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) ta = TenantAccountJoin(
tenant_id=tenant.id, account_id=account.id, role=role
)
db.session.add(ta) db.session.add(ta)
db.session.commit() db.session.commit()
@ -634,7 +720,10 @@ class TenantService:
return ( return (
db.session.query(Tenant) db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) .filter(
TenantAccountJoin.account_id == account.id,
Tenant.status == TenantStatus.NORMAL,
)
.all() .all()
) )
@ -645,7 +734,9 @@ class TenantService:
if not tenant: if not tenant:
raise TenantNotFoundError("Tenant not found.") raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=account.id
).first()
if ta: if ta:
tenant.role = ta.role tenant.role = ta.role
else: else:
@ -672,10 +763,13 @@ class TenantService:
) )
if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") raise AccountNotLinkTenantError(
"Tenant not found or account is not a member of the tenant."
)
else: else:
TenantAccountJoin.query.filter( TenantAccountJoin.query.filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id != tenant_id,
).update({"current": False}) ).update({"current": False})
tenant_account_join.current = True tenant_account_join.current = True
# Set the current tenant for the account # Set the current tenant for the account
@ -730,18 +824,24 @@ class TenantService:
return ( return (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
.filter( .filter(
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.role.in_([role.value for role in roles]),
) )
.first() .first()
is not None is not None
) )
@staticmethod @staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: def get_user_role(
account: Account, tenant: Tenant
) -> Optional[TenantAccountJoinRole]:
"""Get the role of the current account for a given tenant""" """Get the role of the current account for a given tenant"""
join = ( join = (
db.session.query(TenantAccountJoin) db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .filter(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.account_id == account.id,
)
.first() .first()
) )
return join.role if join else None return join.role if join else None
@ -752,7 +852,9 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar()) return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod @staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: def check_member_permission(
tenant: Tenant, operator: Account, member: Account | None, action: str
) -> None:
"""Check member permission""" """Check member permission"""
perms = { perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@ -766,18 +868,26 @@ class TenantService:
if operator.id == member.id: if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() ta_operator = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=operator.id
).first()
if not ta_operator or ta_operator.role not in perms[action]: if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.") raise NoPermissionError(f"No permission to {action} member.")
@staticmethod @staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: def remove_member_from_tenant(
tenant: Tenant, account: Account, operator: Account
) -> None:
"""Remove member from tenant""" """Remove member from tenant"""
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"): if operator.id == account.id and TenantService.check_member_permission(
tenant, operator, account, "remove"
):
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=account.id
).first()
if not ta: if not ta:
raise MemberNotInTenantError("Member not in tenant.") raise MemberNotInTenantError("Member not in tenant.")
@ -785,18 +895,26 @@ class TenantService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: def update_member_role(
tenant: Tenant, member: Account, new_role: str, operator: Account
) -> None:
"""Update member role""" """Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update") TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() target_member_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=member.id
).first()
if target_member_join.role == new_role: if target_member_join.role == new_role:
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") raise RoleAlreadyAssignedError(
"The provided role is already assigned to the member."
)
if new_role == "owner": if new_role == "owner":
# Find the current owner and change their role to 'admin' # Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() current_owner_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, role="owner"
).first()
current_owner_join.role = "admin" current_owner_join.role = "admin"
# Update the role of the target member # Update the role of the target member
@ -806,7 +924,9 @@ class TenantService:
@staticmethod @staticmethod
def dissolve_tenant(tenant: Tenant, operator: Account) -> None: def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
"""Dissolve tenant""" """Dissolve tenant"""
if not TenantService.check_member_permission(tenant, operator, operator, "remove"): if not TenantService.check_member_permission(
tenant, operator, operator, "remove"
):
raise NoPermissionError("No permission to dissolve tenant.") raise NoPermissionError("No permission to dissolve tenant.")
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db.session.delete(tenant) db.session.delete(tenant)
@ -847,7 +967,9 @@ class RegisterService:
account.last_login_ip = ip_address account.last_login_ip = ip_address
account.initialized_at = datetime.now(UTC).replace(tzinfo=None) account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) TenantService.create_owner_tenant_if_not_exist(
account=account, is_setup=True
)
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
db.session.add(dify_setup) db.session.add(dify_setup)
@ -891,7 +1013,10 @@ class RegisterService:
if open_id is not None and provider is not None: if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account) AccountService.link_account_integrate(provider, open_id, account)
if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: if (
FeatureService.get_system_features().is_allow_create_workspace
and create_workspace_required
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
@ -913,7 +1038,12 @@ class RegisterService:
@classmethod @classmethod
def invite_new_member( def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None cls,
tenant: Tenant,
email: str,
language: str,
role: str = "normal",
inviter: Optional[Account] = None,
) -> str: ) -> str:
"""Invite new member""" """Invite new member"""
account = Account.query.filter_by(email=email).first() account = Account.query.filter_by(email=email).first()
@ -924,14 +1054,20 @@ class RegisterService:
name = email.split("@")[0] name = email.split("@")[0]
account = cls.register( account = cls.register(
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True email=email,
name=name,
language=language,
status=AccountStatus.PENDING,
is_setup=True,
) )
# Create new tenant member for invited tenant # Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role) TenantService.create_tenant_member(tenant, account, role)
TenantService.switch_tenant(account, tenant.id) TenantService.switch_tenant(account, tenant.id)
else: else:
TenantService.check_member_permission(tenant, inviter, account, "add") TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=account.id
).first()
if not ta: if not ta:
TenantService.create_tenant_member(tenant, account, role) TenantService.create_tenant_member(tenant, account, role)
@ -962,7 +1098,11 @@ class RegisterService:
"workspace_id": tenant.id, "workspace_id": tenant.id,
} }
expiry_hours = dify_config.INVITE_EXPIRY_HOURS expiry_hours = dify_config.INVITE_EXPIRY_HOURS
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) redis_client.setex(
cls._get_invitation_token_key(token),
expiry_hours * 60 * 60,
json.dumps(invitation_data),
)
return token return token
@classmethod @classmethod
@ -974,7 +1114,9 @@ class RegisterService:
def revoke_token(cls, workspace_id: str, email: str, token: str): def revoke_token(cls, workspace_id: str, email: str, token: str):
if workspace_id and email: if workspace_id and email:
email_hash = sha256(email.encode()).hexdigest() email_hash = sha256(email.encode()).hexdigest()
cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) cache_key = "member_invite_token:{}, {}:{}".format(
workspace_id, email_hash, token
)
redis_client.delete(cache_key) redis_client.delete(cache_key)
else: else:
redis_client.delete(cls._get_invitation_token_key(token)) redis_client.delete(cls._get_invitation_token_key(token))
@ -989,7 +1131,9 @@ class RegisterService:
tenant = ( tenant = (
db.session.query(Tenant) db.session.query(Tenant)
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") .filter(
Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal"
)
.first() .first()
) )
@ -999,7 +1143,10 @@ class RegisterService:
tenant_account = ( tenant_account = (
db.session.query(Account, TenantAccountJoin.role) db.session.query(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) .filter(
Account.email == invitation_data["email"],
TenantAccountJoin.tenant_id == tenant.id,
)
.first() .first()
) )

Loading…
Cancel
Save