format account service

pull/21891/head
ytqh 1 year ago
parent b8a243c5f0
commit 8a43186637

@ -8,10 +8,6 @@ from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, Optional, cast
from pydantic import BaseModel
from sqlalchemy import func
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
@ -32,6 +28,7 @@ from models.account import (
TenantStatus,
)
from models.model import DifySetup
from pydantic import BaseModel
from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
@ -51,11 +48,13 @@ from services.errors.account import (
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
from sqlalchemy import func
from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task
from werkzeug.exceptions import Unauthorized
class TokenPair(BaseModel):
@ -69,12 +68,16 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
reset_password_rate_limiter = RateLimiter(
prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1
)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
)
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
prefix="email_code_account_deletion_rate_limit",
max_attempts=1,
time_window=60 * 1,
)
LOGIN_MAX_ERROR_LIMITS = 5
@ -88,9 +91,15 @@ class AccountService:
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
AccountService._get_refresh_token_key(refresh_token),
REFRESH_TOKEN_EXPIRY,
account_id,
)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id),
REFRESH_TOKEN_EXPIRY,
refresh_token,
)
@staticmethod
@ -107,12 +116,16 @@ class AccountService:
if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.")
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
current_tenant = TenantAccountJoin.query.filter_by(
account_id=account.id, current=True
).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
else:
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
TenantAccountJoin.query.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
)
if not available_ta:
return None
@ -121,7 +134,9 @@ class AccountService:
available_ta.current = True
db.session.commit()
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(
minutes=10
):
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
@ -129,7 +144,9 @@ class AccountService:
@staticmethod
def get_account_jwt_token(account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp_dt = datetime.now(UTC) + timedelta(
minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES
)
exp = int(exp_dt.timestamp())
payload = {
"user_id": account.id,
@ -142,7 +159,9 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
def authenticate(
email: str, password: str, invite_token: Optional[str] = None
) -> Account:
"""authenticate account with email and password"""
account = Account.query.filter_by(email=email).first()
@ -161,7 +180,9 @@ class AccountService:
account.password = base64_password_hashed
account.password_salt = base64_salt
if account.password is None or not compare_password(password, account.password, account.password_salt):
if account.password is None or not compare_password(
password, account.password, account.password_salt
):
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.PENDING.value:
@ -175,7 +196,9 @@ class AccountService:
@staticmethod
def update_account_password(account, password, new_password):
"""update account password"""
if account.password and not compare_password(password, account.password, account.password_salt):
if account.password and not compare_password(
password, account.password, account.password_salt
):
raise CurrentPasswordIncorrectError("Current password is incorrect.")
# may be raised
@ -244,11 +267,18 @@ class AccountService:
@staticmethod
def create_account_in_tenant(
tenant: Tenant, email: str, name: str, interface_language: str, password: Optional[str] = None
tenant: Tenant,
email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
) -> Account:
"""create account"""
account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password
email=email,
name=name,
interface_language=interface_language,
password=password,
)
TenantService.create_tenant_member(tenant, account, role="end_user")
return account
@ -259,7 +289,10 @@ class AccountService:
) -> Account:
"""create account"""
account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password
email=email,
name=name,
interface_language=interface_language,
password=password,
)
TenantService.create_owner_tenant_if_not_exist(account=account)
@ -267,10 +300,14 @@ class AccountService:
return account
@staticmethod
def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]:
def generate_account_deletion_verification_code(
account: Account,
) -> tuple[str, str]:
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, token_type="account_deletion", additional_data={"code": code}
account=account,
token_type="account_deletion",
additional_data={"code": code},
)
return token, code
@ -278,7 +315,9 @@ class AccountService:
def send_account_deletion_verification_email(cls, account: Account, code: str):
email = account.email
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
from controllers.console.auth.error import (
EmailCodeAccountDeletionRateLimitExceededError,
)
raise EmailCodeAccountDeletionRateLimitExceededError()
@ -307,9 +346,11 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
account_integrate: Optional[AccountIntegrate] = (
AccountIntegrate.query.filter_by(
account_id=account.id, provider=provider
).first()
)
if account_integrate:
# If it exists, update the record
@ -319,14 +360,19 @@ class AccountService:
else:
# If it does not exist, create a new record
account_integrate = AccountIntegrate(
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
account_id=account.id,
provider=provider,
open_id=open_id,
encrypted_token="",
)
db.session.add(account_integrate)
db.session.commit()
logging.info(f"Account {account.id} linked {provider} account {open_id}.")
except Exception as e:
logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
logging.exception(
f"Failed to link {provider} account {open_id} to Account {account.id}"
)
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
@ -373,14 +419,20 @@ class AccountService:
@staticmethod
def logout(*, account: Account) -> None:
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
refresh_token = redis_client.get(
AccountService._get_account_refresh_token_key(account.id)
)
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
AccountService._delete_refresh_token(
refresh_token.decode("utf-8"), account.id
)
@staticmethod
def refresh_token(refresh_token: str) -> TokenPair:
# Verify the refresh token
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
account_id = redis_client.get(
AccountService._get_refresh_token_key(refresh_token)
)
if not account_id:
raise ValueError("Invalid refresh token")
@ -413,13 +465,18 @@ class AccountService:
raise ValueError("Email must be provided.")
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
from controllers.console.auth.error import (
PasswordResetRateLimitExceededError,
)
raise PasswordResetRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code}
account=account,
email=email,
token_type="reset_password",
additional_data={"code": code},
)
send_reset_password_mail_task.delay(
language=language,
@ -439,19 +496,27 @@ class AccountService:
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
language: Optional[str] = "en-US",
):
email = account.email if account else email
if email is None:
raise ValueError("Email must be provided.")
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
from controllers.console.auth.error import (
EmailCodeLoginRateLimitExceededError,
)
raise EmailCodeLoginRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
account=account,
email=email,
token_type="email_code_login",
additional_data={"code": code},
)
send_email_code_login_mail_task.delay(
language=language,
@ -541,7 +606,9 @@ class AccountService:
redis_client.setex(freeze_key, 60 * 60, 1)
return True
else:
redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
redis_client.setex(
hour_limit_key, 60 * 10, hour_limit_count + 1
) # first time limit 10 minutes
# add hour limit count
redis_client.incr(hour_limit_key)
@ -566,7 +633,11 @@ class TenantService:
return Tenant.query.filter_by(id=tenant_id).first()
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
def create_tenant(
name: str,
is_setup: Optional[bool] = False,
is_from_dashboard: Optional[bool] = False,
) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
@ -591,38 +662,53 @@ class TenantService:
):
"""Check if user have a workspace or not"""
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
TenantAccountJoin.query.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
)
if available_ta:
return
"""Create owner tenant if not exist"""
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
if (
not FeatureService.get_system_features().is_allow_create_workspace
and not is_setup
):
raise WorkSpaceNotAllowedCreateError()
if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else:
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
tenant = TenantService.create_tenant(
name=f"{account.name}'s Workspace", is_setup=is_setup
)
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
db.session.commit()
tenant_was_created.send(tenant)
@staticmethod
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
def create_tenant_member(
tenant: Tenant, account: Account, role: str = "normal"
) -> TenantAccountJoin:
"""Create tenant member"""
if role == TenantAccountJoinRole.OWNER.value:
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
logging.error(f"Tenant {tenant.id} has already an owner.")
raise Exception("Tenant already has an owner.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = (
db.session.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
if ta:
ta.role = role
else:
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
ta = TenantAccountJoin(
tenant_id=tenant.id, account_id=account.id, role=role
)
db.session.add(ta)
db.session.commit()
@ -634,7 +720,10 @@ class TenantService:
return (
db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.filter(
TenantAccountJoin.account_id == account.id,
Tenant.status == TenantStatus.NORMAL,
)
.all()
)
@ -645,7 +734,9 @@ class TenantService:
if not tenant:
raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=account.id
).first()
if ta:
tenant.role = ta.role
else:
@ -672,10 +763,13 @@ class TenantService:
)
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
raise AccountNotLinkTenantError(
"Tenant not found or account is not a member of the tenant."
)
else:
TenantAccountJoin.query.filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id != tenant_id,
).update({"current": False})
tenant_account_join.current = True
# Set the current tenant for the account
@ -730,18 +824,24 @@ class TenantService:
return (
db.session.query(TenantAccountJoin)
.filter(
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.role.in_([role.value for role in roles]),
)
.first()
is not None
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
def get_user_role(
account: Account, tenant: Tenant
) -> Optional[TenantAccountJoinRole]:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.filter(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.account_id == account.id,
)
.first()
)
return join.role if join else None
@ -752,7 +852,9 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
def check_member_permission(
tenant: Tenant, operator: Account, member: Account | None, action: str
) -> None:
"""Check member permission"""
perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@ -766,18 +868,26 @@ class TenantService:
if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
ta_operator = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=operator.id
).first()
if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
def remove_member_from_tenant(
tenant: Tenant, account: Account, operator: Account
) -> None:
"""Remove member from tenant"""
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
if operator.id == account.id and TenantService.check_member_permission(
tenant, operator, account, "remove"
):
raise CannotOperateSelfError("Cannot operate self.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=account.id
).first()
if not ta:
raise MemberNotInTenantError("Member not in tenant.")
@ -785,18 +895,26 @@ class TenantService:
db.session.commit()
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
def update_member_role(
tenant: Tenant, member: Account, new_role: str, operator: Account
) -> None:
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
target_member_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=member.id
).first()
if target_member_join.role == new_role:
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
raise RoleAlreadyAssignedError(
"The provided role is already assigned to the member."
)
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, role="owner"
).first()
current_owner_join.role = "admin"
# Update the role of the target member
@ -806,7 +924,9 @@ class TenantService:
@staticmethod
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
"""Dissolve tenant"""
if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
if not TenantService.check_member_permission(
tenant, operator, operator, "remove"
):
raise NoPermissionError("No permission to dissolve tenant.")
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db.session.delete(tenant)
@ -847,7 +967,9 @@ class RegisterService:
account.last_login_ip = ip_address
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
TenantService.create_owner_tenant_if_not_exist(
account=account, is_setup=True
)
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
db.session.add(dify_setup)
@ -891,7 +1013,10 @@ class RegisterService:
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
if (
FeatureService.get_system_features().is_allow_create_workspace
and create_workspace_required
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
@ -913,7 +1038,12 @@ class RegisterService:
@classmethod
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
cls,
tenant: Tenant,
email: str,
language: str,
role: str = "normal",
inviter: Optional[Account] = None,
) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
@ -924,14 +1054,20 @@ class RegisterService:
name = email.split("@")[0]
account = cls.register(
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
email=email,
name=name,
language=language,
status=AccountStatus.PENDING,
is_setup=True,
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id, account_id=account.id
).first()
if not ta:
TenantService.create_tenant_member(tenant, account, role)
@ -962,7 +1098,11 @@ class RegisterService:
"workspace_id": tenant.id,
}
expiry_hours = dify_config.INVITE_EXPIRY_HOURS
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
redis_client.setex(
cls._get_invitation_token_key(token),
expiry_hours * 60 * 60,
json.dumps(invitation_data),
)
return token
@classmethod
@ -974,7 +1114,9 @@ class RegisterService:
def revoke_token(cls, workspace_id: str, email: str, token: str):
if workspace_id and email:
email_hash = sha256(email.encode()).hexdigest()
cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
cache_key = "member_invite_token:{}, {}:{}".format(
workspace_id, email_hash, token
)
redis_client.delete(cache_key)
else:
redis_client.delete(cls._get_invitation_token_key(token))
@ -989,7 +1131,9 @@ class RegisterService:
tenant = (
db.session.query(Tenant)
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.filter(
Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal"
)
.first()
)
@ -999,7 +1143,10 @@ class RegisterService:
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.filter(
Account.email == invitation_data["email"],
TenantAccountJoin.tenant_id == tenant.id,
)
.first()
)

Loading…
Cancel
Save