|
|
|
@ -70,7 +70,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AccountService:
|
|
|
|
class AccountService:
|
|
|
|
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
|
|
|
reset_password_rate_limiter = RateLimiter(
|
|
|
|
|
|
|
|
prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1
|
|
|
|
|
|
|
|
)
|
|
|
|
email_code_login_rate_limiter = RateLimiter(
|
|
|
|
email_code_login_rate_limiter = RateLimiter(
|
|
|
|
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
|
|
|
|
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
|
|
|
|
)
|
|
|
|
)
|
|
|
|
@ -120,12 +122,16 @@ class AccountService:
|
|
|
|
if account.status == AccountStatus.BANNED.value:
|
|
|
|
if account.status == AccountStatus.BANNED.value:
|
|
|
|
raise Unauthorized("Account is banned.")
|
|
|
|
raise Unauthorized("Account is banned.")
|
|
|
|
|
|
|
|
|
|
|
|
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
|
|
|
current_tenant = TenantAccountJoin.query.filter_by(
|
|
|
|
|
|
|
|
account_id=account.id, current=True
|
|
|
|
|
|
|
|
).first()
|
|
|
|
if current_tenant:
|
|
|
|
if current_tenant:
|
|
|
|
account.current_tenant_id = current_tenant.tenant_id
|
|
|
|
account.current_tenant_id = current_tenant.tenant_id
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
available_ta = (
|
|
|
|
available_ta = (
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id)
|
|
|
|
|
|
|
|
.order_by(TenantAccountJoin.id.asc())
|
|
|
|
|
|
|
|
.first()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if not available_ta:
|
|
|
|
if not available_ta:
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
@ -134,7 +140,9 @@ class AccountService:
|
|
|
|
available_ta.current = True
|
|
|
|
available_ta.current = True
|
|
|
|
db.session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
|
|
|
|
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(
|
|
|
|
|
|
|
|
minutes=10
|
|
|
|
|
|
|
|
):
|
|
|
|
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
|
|
|
|
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
|
|
|
|
db.session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
@ -142,7 +150,9 @@ class AccountService:
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_account_jwt_token(account: Account) -> str:
|
|
|
|
def get_account_jwt_token(account: Account) -> str:
|
|
|
|
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
|
exp_dt = datetime.now(UTC) + timedelta(
|
|
|
|
|
|
|
|
minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES
|
|
|
|
|
|
|
|
)
|
|
|
|
exp = int(exp_dt.timestamp())
|
|
|
|
exp = int(exp_dt.timestamp())
|
|
|
|
payload = {
|
|
|
|
payload = {
|
|
|
|
"user_id": account.id,
|
|
|
|
"user_id": account.id,
|
|
|
|
@ -155,7 +165,9 @@ class AccountService:
|
|
|
|
return token
|
|
|
|
return token
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
|
|
|
|
def authenticate(
|
|
|
|
|
|
|
|
email: str, password: str, invite_token: Optional[str] = None
|
|
|
|
|
|
|
|
) -> Account:
|
|
|
|
"""authenticate account with email and password"""
|
|
|
|
"""authenticate account with email and password"""
|
|
|
|
|
|
|
|
|
|
|
|
account = db.session.query(Account).filter_by(email=email).first()
|
|
|
|
account = db.session.query(Account).filter_by(email=email).first()
|
|
|
|
@ -174,7 +186,9 @@ class AccountService:
|
|
|
|
account.password = base64_password_hashed
|
|
|
|
account.password = base64_password_hashed
|
|
|
|
account.password_salt = base64_salt
|
|
|
|
account.password_salt = base64_salt
|
|
|
|
|
|
|
|
|
|
|
|
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
|
|
|
if account.password is None or not compare_password(
|
|
|
|
|
|
|
|
password, account.password, account.password_salt
|
|
|
|
|
|
|
|
):
|
|
|
|
raise AccountPasswordError("Invalid email or password.")
|
|
|
|
raise AccountPasswordError("Invalid email or password.")
|
|
|
|
|
|
|
|
|
|
|
|
if account.status == AccountStatus.PENDING.value:
|
|
|
|
if account.status == AccountStatus.PENDING.value:
|
|
|
|
@ -188,7 +202,9 @@ class AccountService:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def update_account_password(account, password, new_password):
|
|
|
|
def update_account_password(account, password, new_password):
|
|
|
|
"""update account password"""
|
|
|
|
"""update account password"""
|
|
|
|
if account.password and not compare_password(password, account.password, account.password_salt):
|
|
|
|
if account.password and not compare_password(
|
|
|
|
|
|
|
|
password, account.password, account.password_salt
|
|
|
|
|
|
|
|
):
|
|
|
|
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
|
|
|
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
|
|
|
|
|
|
|
|
|
|
|
# may be raised
|
|
|
|
# may be raised
|
|
|
|
@ -305,7 +321,9 @@ class AccountService:
|
|
|
|
def send_account_deletion_verification_email(cls, account: Account, code: str):
|
|
|
|
def send_account_deletion_verification_email(cls, account: Account, code: str):
|
|
|
|
email = account.email
|
|
|
|
email = account.email
|
|
|
|
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
|
|
|
|
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
|
|
|
|
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
|
|
|
|
from controllers.console.auth.error import (
|
|
|
|
|
|
|
|
EmailCodeAccountDeletionRateLimitExceededError,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
raise EmailCodeAccountDeletionRateLimitExceededError()
|
|
|
|
raise EmailCodeAccountDeletionRateLimitExceededError()
|
|
|
|
|
|
|
|
|
|
|
|
@ -334,9 +352,11 @@ class AccountService:
|
|
|
|
"""Link account integrate"""
|
|
|
|
"""Link account integrate"""
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
# Query whether there is an existing binding record for the same provider
|
|
|
|
# Query whether there is an existing binding record for the same provider
|
|
|
|
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
|
|
|
|
account_integrate: Optional[AccountIntegrate] = (
|
|
|
|
account_id=account.id, provider=provider
|
|
|
|
AccountIntegrate.query.filter_by(
|
|
|
|
).first()
|
|
|
|
account_id=account.id, provider=provider
|
|
|
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if account_integrate:
|
|
|
|
if account_integrate:
|
|
|
|
# If it exists, update the record
|
|
|
|
# If it exists, update the record
|
|
|
|
@ -356,7 +376,9 @@ class AccountService:
|
|
|
|
db.session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
logging.info(f"Account {account.id} linked {provider} account {open_id}.")
|
|
|
|
logging.info(f"Account {account.id} linked {provider} account {open_id}.")
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
|
|
|
|
logging.exception(
|
|
|
|
|
|
|
|
f"Failed to link {provider} account {open_id} to Account {account.id}"
|
|
|
|
|
|
|
|
)
|
|
|
|
raise LinkAccountIntegrateError("Failed to link account.") from e
|
|
|
|
raise LinkAccountIntegrateError("Failed to link account.") from e
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
@ -403,14 +425,20 @@ class AccountService:
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def logout(*, account: Account) -> None:
|
|
|
|
def logout(*, account: Account) -> None:
|
|
|
|
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
|
|
|
|
refresh_token = redis_client.get(
|
|
|
|
|
|
|
|
AccountService._get_account_refresh_token_key(account.id)
|
|
|
|
|
|
|
|
)
|
|
|
|
if refresh_token:
|
|
|
|
if refresh_token:
|
|
|
|
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
|
|
|
AccountService._delete_refresh_token(
|
|
|
|
|
|
|
|
refresh_token.decode("utf-8"), account.id
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def refresh_token(refresh_token: str) -> TokenPair:
|
|
|
|
def refresh_token(refresh_token: str) -> TokenPair:
|
|
|
|
# Verify the refresh token
|
|
|
|
# Verify the refresh token
|
|
|
|
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
|
|
|
|
account_id = redis_client.get(
|
|
|
|
|
|
|
|
AccountService._get_refresh_token_key(refresh_token)
|
|
|
|
|
|
|
|
)
|
|
|
|
if not account_id:
|
|
|
|
if not account_id:
|
|
|
|
raise ValueError("Invalid refresh token")
|
|
|
|
raise ValueError("Invalid refresh token")
|
|
|
|
|
|
|
|
|
|
|
|
@ -443,7 +471,9 @@ class AccountService:
|
|
|
|
raise ValueError("Email must be provided.")
|
|
|
|
raise ValueError("Email must be provided.")
|
|
|
|
|
|
|
|
|
|
|
|
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
|
|
|
|
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
|
|
|
|
from controllers.console.auth.error import PasswordResetRateLimitExceededError
|
|
|
|
from controllers.console.auth.error import (
|
|
|
|
|
|
|
|
PasswordResetRateLimitExceededError,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
raise PasswordResetRateLimitExceededError()
|
|
|
|
raise PasswordResetRateLimitExceededError()
|
|
|
|
|
|
|
|
|
|
|
|
@ -469,7 +499,10 @@ class AccountService:
|
|
|
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
|
|
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
|
|
|
additional_data["code"] = code
|
|
|
|
additional_data["code"] = code
|
|
|
|
token = TokenManager.generate_token(
|
|
|
|
token = TokenManager.generate_token(
|
|
|
|
account=account, email=email, token_type="reset_password", additional_data=additional_data
|
|
|
|
account=account,
|
|
|
|
|
|
|
|
email=email,
|
|
|
|
|
|
|
|
token_type="reset_password",
|
|
|
|
|
|
|
|
additional_data=additional_data,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return code, token
|
|
|
|
return code, token
|
|
|
|
|
|
|
|
|
|
|
|
@ -492,10 +525,14 @@ class AccountService:
|
|
|
|
if email is None:
|
|
|
|
if email is None:
|
|
|
|
raise ValueError("Email must be provided.")
|
|
|
|
raise ValueError("Email must be provided.")
|
|
|
|
|
|
|
|
|
|
|
|
if dify_config.DEBUG_ORG_EMAIL_DOMAIN and email.endswith(dify_config.DEBUG_ORG_EMAIL_DOMAIN):
|
|
|
|
if dify_config.DEBUG_ORG_EMAIL_DOMAIN and email.endswith(
|
|
|
|
|
|
|
|
dify_config.DEBUG_ORG_EMAIL_DOMAIN
|
|
|
|
|
|
|
|
):
|
|
|
|
code = dify_config.DEBUG_CODE_FOR_LOGIN
|
|
|
|
code = dify_config.DEBUG_CODE_FOR_LOGIN
|
|
|
|
elif cls.email_code_login_rate_limiter.is_rate_limited(email):
|
|
|
|
elif cls.email_code_login_rate_limiter.is_rate_limited(email):
|
|
|
|
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
|
|
|
|
from controllers.console.auth.error import (
|
|
|
|
|
|
|
|
EmailCodeLoginRateLimitExceededError,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
raise EmailCodeLoginRateLimitExceededError()
|
|
|
|
raise EmailCodeLoginRateLimitExceededError()
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
@ -622,7 +659,9 @@ class AccountService:
|
|
|
|
redis_client.setex(freeze_key, 60 * 60, 1)
|
|
|
|
redis_client.setex(freeze_key, 60 * 60, 1)
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
|
|
|
|
redis_client.setex(
|
|
|
|
|
|
|
|
hour_limit_key, 60 * 10, hour_limit_count + 1
|
|
|
|
|
|
|
|
) # first time limit 10 minutes
|
|
|
|
|
|
|
|
|
|
|
|
# add hour limit count
|
|
|
|
# add hour limit count
|
|
|
|
redis_client.incr(hour_limit_key)
|
|
|
|
redis_client.incr(hour_limit_key)
|
|
|
|
@ -647,11 +686,7 @@ class AccountService:
|
|
|
|
Raises Unauthorized if account is banned.
|
|
|
|
Raises Unauthorized if account is banned.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# Query directly with phone number first
|
|
|
|
# Query directly with phone number first
|
|
|
|
admin_account = (
|
|
|
|
admin_account = db.session.query(Account).filter(Account.phone == phone).first()
|
|
|
|
db.session.query(Account)
|
|
|
|
|
|
|
|
.filter(Account.phone == phone)
|
|
|
|
|
|
|
|
.first()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not admin_account:
|
|
|
|
if not admin_account:
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
@ -662,7 +697,9 @@ class AccountService:
|
|
|
|
organization_id = admin_account.current_organization_id
|
|
|
|
organization_id = admin_account.current_organization_id
|
|
|
|
|
|
|
|
|
|
|
|
if not organization_id:
|
|
|
|
if not organization_id:
|
|
|
|
logging.warning(f"Account {admin_account.id} is not a member of any organization.")
|
|
|
|
logging.warning(
|
|
|
|
|
|
|
|
f"Account {admin_account.id} is not a member of any organization."
|
|
|
|
|
|
|
|
)
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
# If organization_id is provided, check if account is an admin member of that organization
|
|
|
|
# If organization_id is provided, check if account is an admin member of that organization
|
|
|
|
@ -673,26 +710,28 @@ class AccountService:
|
|
|
|
.filter(
|
|
|
|
.filter(
|
|
|
|
OrganizationMember.organization_id == organization_id,
|
|
|
|
OrganizationMember.organization_id == organization_id,
|
|
|
|
OrganizationMember.account_id == admin_account.id,
|
|
|
|
OrganizationMember.account_id == admin_account.id,
|
|
|
|
OrganizationMember.role == OrganizationRole.ADMIN
|
|
|
|
OrganizationMember.role == OrganizationRole.ADMIN,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
.first()
|
|
|
|
.first()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if not org_member:
|
|
|
|
if not org_member:
|
|
|
|
logging.warning(f"Account {admin_account.id} is not a member of any organization.")
|
|
|
|
logging.warning(
|
|
|
|
|
|
|
|
f"Account {admin_account.id} is not a member of any organization."
|
|
|
|
|
|
|
|
)
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
return admin_account
|
|
|
|
return admin_account
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def is_phone_send_ip_limit(cls, ip_address: str) -> bool:
|
|
|
|
def is_login_attempt_ip_limit(cls, ip_address: str) -> bool:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Check if IP has reached the limit for sending phone verification codes.
|
|
|
|
Check if IP has reached the limit for sending phone verification codes.
|
|
|
|
Similar to is_email_send_ip_limit but for phone verification.
|
|
|
|
Similar to is_email_send_ip_limit but for phone verification.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
minute_key = f"phone_send_ip_limit_minute:{ip_address}"
|
|
|
|
minute_key = f"login_attempt_ip_limit_minute:{ip_address}"
|
|
|
|
freeze_key = f"phone_send_ip_limit_freeze:{ip_address}"
|
|
|
|
freeze_key = f"login_attempt_ip_limit_freeze:{ip_address}"
|
|
|
|
hour_limit_key = f"phone_send_ip_limit_hour:{ip_address}"
|
|
|
|
hour_limit_key = f"login_attempt_ip_limit_hour:{ip_address}"
|
|
|
|
|
|
|
|
|
|
|
|
# check ip is frozen
|
|
|
|
# check ip is frozen
|
|
|
|
if redis_client.get(freeze_key):
|
|
|
|
if redis_client.get(freeze_key):
|
|
|
|
@ -705,7 +744,9 @@ class AccountService:
|
|
|
|
current_minute_count = int(current_minute_count)
|
|
|
|
current_minute_count = int(current_minute_count)
|
|
|
|
|
|
|
|
|
|
|
|
# check current hour count
|
|
|
|
# check current hour count
|
|
|
|
if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE: # Use same limit as email
|
|
|
|
if (
|
|
|
|
|
|
|
|
current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE
|
|
|
|
|
|
|
|
): # Use same limit as email
|
|
|
|
hour_limit_count = redis_client.get(hour_limit_key)
|
|
|
|
hour_limit_count = redis_client.get(hour_limit_key)
|
|
|
|
if hour_limit_count is None:
|
|
|
|
if hour_limit_count is None:
|
|
|
|
hour_limit_count = 0
|
|
|
|
hour_limit_count = 0
|
|
|
|
@ -715,7 +756,9 @@ class AccountService:
|
|
|
|
redis_client.setex(freeze_key, 60 * 60, 1)
|
|
|
|
redis_client.setex(freeze_key, 60 * 60, 1)
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
|
|
|
|
redis_client.setex(
|
|
|
|
|
|
|
|
hour_limit_key, 60 * 10, hour_limit_count + 1
|
|
|
|
|
|
|
|
) # first time limit 10 minutes
|
|
|
|
|
|
|
|
|
|
|
|
# add hour limit count
|
|
|
|
# add hour limit count
|
|
|
|
redis_client.incr(hour_limit_key)
|
|
|
|
redis_client.incr(hour_limit_key)
|
|
|
|
@ -769,6 +812,34 @@ class AccountService:
|
|
|
|
"""Revoke phone code login token"""
|
|
|
|
"""Revoke phone code login token"""
|
|
|
|
TokenManager.revoke_token(token, "phone_code_login")
|
|
|
|
TokenManager.revoke_token(token, "phone_code_login")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def get_admin_through_login_id(cls, login_id: str):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Get admin account through login ID (either email or phone number).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
login_id: The email or phone number to search for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns None if no admin account with this ID exists.
|
|
|
|
|
|
|
|
Raises Unauthorized if account is banned.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
account = (
|
|
|
|
|
|
|
|
db.session.query(Account)
|
|
|
|
|
|
|
|
.filter((Account.email == login_id) | (Account.phone == login_id))
|
|
|
|
|
|
|
|
.first()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not account:
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if account.status == AccountStatus.BANNED.value:
|
|
|
|
|
|
|
|
raise Unauthorized("Account is banned.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not account.is_org_admin():
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return account
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TenantService:
|
|
|
|
class TenantService:
|
|
|
|
|
|
|
|
|
|
|
|
@ -806,38 +877,53 @@ class TenantService:
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""Check if user have a workspace or not"""
|
|
|
|
"""Check if user have a workspace or not"""
|
|
|
|
available_ta = (
|
|
|
|
available_ta = (
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id)
|
|
|
|
|
|
|
|
.order_by(TenantAccountJoin.id.asc())
|
|
|
|
|
|
|
|
.first()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if available_ta:
|
|
|
|
if available_ta:
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
"""Create owner tenant if not exist"""
|
|
|
|
"""Create owner tenant if not exist"""
|
|
|
|
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
|
|
|
|
if (
|
|
|
|
|
|
|
|
not FeatureService.get_system_features().is_allow_create_workspace
|
|
|
|
|
|
|
|
and not is_setup
|
|
|
|
|
|
|
|
):
|
|
|
|
raise WorkSpaceNotAllowedCreateError()
|
|
|
|
raise WorkSpaceNotAllowedCreateError()
|
|
|
|
|
|
|
|
|
|
|
|
if name:
|
|
|
|
if name:
|
|
|
|
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
|
|
|
|
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
|
|
|
|
tenant = TenantService.create_tenant(
|
|
|
|
|
|
|
|
name=f"{account.name}'s Workspace", is_setup=is_setup
|
|
|
|
|
|
|
|
)
|
|
|
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
|
|
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
|
|
|
account.current_tenant = tenant
|
|
|
|
account.current_tenant = tenant
|
|
|
|
db.session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
tenant_was_created.send(tenant)
|
|
|
|
tenant_was_created.send(tenant)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
|
|
|
|
def create_tenant_member(
|
|
|
|
|
|
|
|
tenant: Tenant, account: Account, role: str = "normal"
|
|
|
|
|
|
|
|
) -> TenantAccountJoin:
|
|
|
|
"""Create tenant member"""
|
|
|
|
"""Create tenant member"""
|
|
|
|
if role == TenantAccountRole.OWNER.value:
|
|
|
|
if role == TenantAccountRole.OWNER.value:
|
|
|
|
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
|
|
|
|
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
|
|
|
|
logging.error(f"Tenant {tenant.id} has already an owner.")
|
|
|
|
logging.error(f"Tenant {tenant.id} has already an owner.")
|
|
|
|
raise Exception("Tenant already has an owner.")
|
|
|
|
raise Exception("Tenant already has an owner.")
|
|
|
|
|
|
|
|
|
|
|
|
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
ta = (
|
|
|
|
|
|
|
|
db.session.query(TenantAccountJoin)
|
|
|
|
|
|
|
|
.filter_by(tenant_id=tenant.id, account_id=account.id)
|
|
|
|
|
|
|
|
.first()
|
|
|
|
|
|
|
|
)
|
|
|
|
if ta:
|
|
|
|
if ta:
|
|
|
|
ta.role = role
|
|
|
|
ta.role = role
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
|
|
|
|
ta = TenantAccountJoin(
|
|
|
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id, role=role
|
|
|
|
|
|
|
|
)
|
|
|
|
db.session.add(ta)
|
|
|
|
db.session.add(ta)
|
|
|
|
|
|
|
|
|
|
|
|
db.session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
@ -863,7 +949,9 @@ class TenantService:
|
|
|
|
if not tenant:
|
|
|
|
if not tenant:
|
|
|
|
raise TenantNotFoundError("Tenant not found.")
|
|
|
|
raise TenantNotFoundError("Tenant not found.")
|
|
|
|
|
|
|
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
ta = TenantAccountJoin.query.filter_by(
|
|
|
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id
|
|
|
|
|
|
|
|
).first()
|
|
|
|
if ta:
|
|
|
|
if ta:
|
|
|
|
tenant.role = ta.role
|
|
|
|
tenant.role = ta.role
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
@ -890,7 +978,9 @@ class TenantService:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if not tenant_account_join:
|
|
|
|
if not tenant_account_join:
|
|
|
|
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
|
|
|
raise AccountNotLinkTenantError(
|
|
|
|
|
|
|
|
"Tenant not found or account is not a member of the tenant."
|
|
|
|
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
TenantAccountJoin.query.filter(
|
|
|
|
TenantAccountJoin.query.filter(
|
|
|
|
TenantAccountJoin.account_id == account.id,
|
|
|
|
TenantAccountJoin.account_id == account.id,
|
|
|
|
@ -975,7 +1065,9 @@ class TenantService:
|
|
|
|
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
|
|
|
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
|
|
|
|
def check_member_permission(
|
|
|
|
|
|
|
|
tenant: Tenant, operator: Account, member: Account | None, action: str
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
"""Check member permission"""
|
|
|
|
"""Check member permission"""
|
|
|
|
perms = {
|
|
|
|
perms = {
|
|
|
|
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
|
|
|
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
|
|
|
@ -989,20 +1081,26 @@ class TenantService:
|
|
|
|
if operator.id == member.id:
|
|
|
|
if operator.id == member.id:
|
|
|
|
raise CannotOperateSelfError("Cannot operate self.")
|
|
|
|
raise CannotOperateSelfError("Cannot operate self.")
|
|
|
|
|
|
|
|
|
|
|
|
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
|
|
|
|
ta_operator = TenantAccountJoin.query.filter_by(
|
|
|
|
|
|
|
|
tenant_id=tenant.id, account_id=operator.id
|
|
|
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
|
|
if not ta_operator or ta_operator.role not in perms[action]:
|
|
|
|
if not ta_operator or ta_operator.role not in perms[action]:
|
|
|
|
raise NoPermissionError(f"No permission to {action} member.")
|
|
|
|
raise NoPermissionError(f"No permission to {action} member.")
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
|
|
|
def remove_member_from_tenant(
|
|
|
|
|
|
|
|
tenant: Tenant, account: Account, operator: Account
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
"""Remove member from tenant"""
|
|
|
|
"""Remove member from tenant"""
|
|
|
|
if operator.id == account.id:
|
|
|
|
if operator.id == account.id:
|
|
|
|
raise CannotOperateSelfError("Cannot operate self.")
|
|
|
|
raise CannotOperateSelfError("Cannot operate self.")
|
|
|
|
|
|
|
|
|
|
|
|
TenantService.check_member_permission(tenant, operator, account, "remove")
|
|
|
|
TenantService.check_member_permission(tenant, operator, account, "remove")
|
|
|
|
|
|
|
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
ta = TenantAccountJoin.query.filter_by(
|
|
|
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id
|
|
|
|
|
|
|
|
).first()
|
|
|
|
if not ta:
|
|
|
|
if not ta:
|
|
|
|
raise MemberNotInTenantError("Member not in tenant.")
|
|
|
|
raise MemberNotInTenantError("Member not in tenant.")
|
|
|
|
|
|
|
|
|
|
|
|
@ -1010,18 +1108,26 @@ class TenantService:
|
|
|
|
db.session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
|
|
|
|
def update_member_role(
|
|
|
|
|
|
|
|
tenant: Tenant, member: Account, new_role: str, operator: Account
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
"""Update member role"""
|
|
|
|
"""Update member role"""
|
|
|
|
TenantService.check_member_permission(tenant, operator, member, "update")
|
|
|
|
TenantService.check_member_permission(tenant, operator, member, "update")
|
|
|
|
|
|
|
|
|
|
|
|
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
|
|
|
|
target_member_join = TenantAccountJoin.query.filter_by(
|
|
|
|
|
|
|
|
tenant_id=tenant.id, account_id=member.id
|
|
|
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
|
|
if target_member_join.role == new_role:
|
|
|
|
if target_member_join.role == new_role:
|
|
|
|
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
|
|
|
|
raise RoleAlreadyAssignedError(
|
|
|
|
|
|
|
|
"The provided role is already assigned to the member."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if new_role == "owner":
|
|
|
|
if new_role == "owner":
|
|
|
|
# Find the current owner and change their role to 'admin'
|
|
|
|
# Find the current owner and change their role to 'admin'
|
|
|
|
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
|
|
|
|
current_owner_join = TenantAccountJoin.query.filter_by(
|
|
|
|
|
|
|
|
tenant_id=tenant.id, role="owner"
|
|
|
|
|
|
|
|
).first()
|
|
|
|
current_owner_join.role = "admin"
|
|
|
|
current_owner_join.role = "admin"
|
|
|
|
|
|
|
|
|
|
|
|
# Update the role of the target member
|
|
|
|
# Update the role of the target member
|
|
|
|
@ -1031,7 +1137,9 @@ class TenantService:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
|
|
|
|
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
|
|
|
|
"""Dissolve tenant"""
|
|
|
|
"""Dissolve tenant"""
|
|
|
|
if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
|
|
|
|
if not TenantService.check_member_permission(
|
|
|
|
|
|
|
|
tenant, operator, operator, "remove"
|
|
|
|
|
|
|
|
):
|
|
|
|
raise NoPermissionError("No permission to dissolve tenant.")
|
|
|
|
raise NoPermissionError("No permission to dissolve tenant.")
|
|
|
|
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
|
|
|
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
|
|
|
db.session.delete(tenant)
|
|
|
|
db.session.delete(tenant)
|
|
|
|
@ -1116,7 +1224,10 @@ class RegisterService:
|
|
|
|
if open_id is not None and provider is not None:
|
|
|
|
if open_id is not None and provider is not None:
|
|
|
|
AccountService.link_account_integrate(provider, open_id, account)
|
|
|
|
AccountService.link_account_integrate(provider, open_id, account)
|
|
|
|
|
|
|
|
|
|
|
|
if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
|
|
|
|
if (
|
|
|
|
|
|
|
|
FeatureService.get_system_features().is_allow_create_workspace
|
|
|
|
|
|
|
|
and create_workspace_required
|
|
|
|
|
|
|
|
):
|
|
|
|
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
|
|
|
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
|
|
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
|
|
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
|
|
|
account.current_tenant = tenant
|
|
|
|
account.current_tenant = tenant
|
|
|
|
@ -1140,7 +1251,12 @@ class RegisterService:
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def invite_new_member(
|
|
|
|
def invite_new_member(
|
|
|
|
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
|
|
|
|
cls,
|
|
|
|
|
|
|
|
tenant: Tenant,
|
|
|
|
|
|
|
|
email: str,
|
|
|
|
|
|
|
|
language: str,
|
|
|
|
|
|
|
|
role: str = "normal",
|
|
|
|
|
|
|
|
inviter: Account | None = None,
|
|
|
|
) -> str:
|
|
|
|
) -> str:
|
|
|
|
if not inviter:
|
|
|
|
if not inviter:
|
|
|
|
raise ValueError("Inviter is required")
|
|
|
|
raise ValueError("Inviter is required")
|
|
|
|
@ -1165,7 +1281,9 @@ class RegisterService:
|
|
|
|
TenantService.switch_tenant(account, tenant.id)
|
|
|
|
TenantService.switch_tenant(account, tenant.id)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
TenantService.check_member_permission(tenant, inviter, account, "add")
|
|
|
|
TenantService.check_member_permission(tenant, inviter, account, "add")
|
|
|
|
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
ta = TenantAccountJoin.query.filter_by(
|
|
|
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id
|
|
|
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
|
|
if not ta:
|
|
|
|
if not ta:
|
|
|
|
TenantService.create_tenant_member(tenant, account, role)
|
|
|
|
TenantService.create_tenant_member(tenant, account, role)
|
|
|
|
@ -1212,7 +1330,9 @@ class RegisterService:
|
|
|
|
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
|
|
|
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
|
|
|
if workspace_id and email:
|
|
|
|
if workspace_id and email:
|
|
|
|
email_hash = sha256(email.encode()).hexdigest()
|
|
|
|
email_hash = sha256(email.encode()).hexdigest()
|
|
|
|
cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
|
|
|
|
cache_key = "member_invite_token:{}, {}:{}".format(
|
|
|
|
|
|
|
|
workspace_id, email_hash, token
|
|
|
|
|
|
|
|
)
|
|
|
|
redis_client.delete(cache_key)
|
|
|
|
redis_client.delete(cache_key)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
redis_client.delete(cls._get_invitation_token_key(token))
|
|
|
|
redis_client.delete(cls._get_invitation_token_key(token))
|
|
|
|
@ -1227,7 +1347,9 @@ class RegisterService:
|
|
|
|
|
|
|
|
|
|
|
|
tenant = (
|
|
|
|
tenant = (
|
|
|
|
db.session.query(Tenant)
|
|
|
|
db.session.query(Tenant)
|
|
|
|
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
|
|
|
|
.filter(
|
|
|
|
|
|
|
|
Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal"
|
|
|
|
|
|
|
|
)
|
|
|
|
.first()
|
|
|
|
.first()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|