|
|
|
|
@ -8,10 +8,6 @@ from datetime import UTC, datetime, timedelta
|
|
|
|
|
from hashlib import sha256
|
|
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from sqlalchemy import func
|
|
|
|
|
from werkzeug.exceptions import Unauthorized
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
from constants.languages import language_timezone_mapping, languages
|
|
|
|
|
from events.tenant_event import tenant_was_created
|
|
|
|
|
@ -32,6 +28,7 @@ from models.account import (
|
|
|
|
|
TenantStatus,
|
|
|
|
|
)
|
|
|
|
|
from models.model import DifySetup
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from services.billing_service import BillingService
|
|
|
|
|
from services.errors.account import (
|
|
|
|
|
AccountAlreadyInTenantError,
|
|
|
|
|
@ -51,11 +48,13 @@ from services.errors.account import (
|
|
|
|
|
)
|
|
|
|
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
|
|
|
|
from services.feature_service import FeatureService
|
|
|
|
|
from sqlalchemy import func
|
|
|
|
|
from tasks.delete_account_task import delete_account_task
|
|
|
|
|
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
|
|
|
|
|
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
|
|
|
|
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
|
|
|
|
from tasks.mail_reset_password_task import send_reset_password_mail_task
|
|
|
|
|
from werkzeug.exceptions import Unauthorized
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenPair(BaseModel):
|
|
|
|
|
@ -69,12 +68,16 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AccountService:
|
|
|
|
|
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
|
|
|
|
reset_password_rate_limiter = RateLimiter(
|
|
|
|
|
prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1
|
|
|
|
|
)
|
|
|
|
|
email_code_login_rate_limiter = RateLimiter(
|
|
|
|
|
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
|
|
|
|
|
)
|
|
|
|
|
email_code_account_deletion_rate_limiter = RateLimiter(
|
|
|
|
|
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
|
|
|
|
|
prefix="email_code_account_deletion_rate_limit",
|
|
|
|
|
max_attempts=1,
|
|
|
|
|
time_window=60 * 1,
|
|
|
|
|
)
|
|
|
|
|
LOGIN_MAX_ERROR_LIMITS = 5
|
|
|
|
|
|
|
|
|
|
@ -88,9 +91,15 @@ class AccountService:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
|
|
|
|
|
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
|
|
|
|
|
redis_client.setex(
|
|
|
|
|
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
|
|
|
|
|
AccountService._get_refresh_token_key(refresh_token),
|
|
|
|
|
REFRESH_TOKEN_EXPIRY,
|
|
|
|
|
account_id,
|
|
|
|
|
)
|
|
|
|
|
redis_client.setex(
|
|
|
|
|
AccountService._get_account_refresh_token_key(account_id),
|
|
|
|
|
REFRESH_TOKEN_EXPIRY,
|
|
|
|
|
refresh_token,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@ -107,12 +116,16 @@ class AccountService:
|
|
|
|
|
if account.status == AccountStatus.BANNED.value:
|
|
|
|
|
raise Unauthorized("Account is banned.")
|
|
|
|
|
|
|
|
|
|
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
|
|
|
|
current_tenant = TenantAccountJoin.query.filter_by(
|
|
|
|
|
account_id=account.id, current=True
|
|
|
|
|
).first()
|
|
|
|
|
if current_tenant:
|
|
|
|
|
account.current_tenant_id = current_tenant.tenant_id
|
|
|
|
|
else:
|
|
|
|
|
available_ta = (
|
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id)
|
|
|
|
|
.order_by(TenantAccountJoin.id.asc())
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if not available_ta:
|
|
|
|
|
return None
|
|
|
|
|
@ -121,7 +134,9 @@ class AccountService:
|
|
|
|
|
available_ta.current = True
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
|
|
|
|
|
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(
|
|
|
|
|
minutes=10
|
|
|
|
|
):
|
|
|
|
|
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
@ -129,7 +144,9 @@ class AccountService:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_account_jwt_token(account: Account) -> str:
|
|
|
|
|
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
|
|
exp_dt = datetime.now(UTC) + timedelta(
|
|
|
|
|
minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES
|
|
|
|
|
)
|
|
|
|
|
exp = int(exp_dt.timestamp())
|
|
|
|
|
payload = {
|
|
|
|
|
"user_id": account.id,
|
|
|
|
|
@ -142,7 +159,9 @@ class AccountService:
|
|
|
|
|
return token
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
|
|
|
|
|
def authenticate(
|
|
|
|
|
email: str, password: str, invite_token: Optional[str] = None
|
|
|
|
|
) -> Account:
|
|
|
|
|
"""authenticate account with email and password"""
|
|
|
|
|
|
|
|
|
|
account = Account.query.filter_by(email=email).first()
|
|
|
|
|
@ -161,7 +180,9 @@ class AccountService:
|
|
|
|
|
account.password = base64_password_hashed
|
|
|
|
|
account.password_salt = base64_salt
|
|
|
|
|
|
|
|
|
|
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
|
|
|
|
if account.password is None or not compare_password(
|
|
|
|
|
password, account.password, account.password_salt
|
|
|
|
|
):
|
|
|
|
|
raise AccountPasswordError("Invalid email or password.")
|
|
|
|
|
|
|
|
|
|
if account.status == AccountStatus.PENDING.value:
|
|
|
|
|
@ -175,7 +196,9 @@ class AccountService:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def update_account_password(account, password, new_password):
|
|
|
|
|
"""update account password"""
|
|
|
|
|
if account.password and not compare_password(password, account.password, account.password_salt):
|
|
|
|
|
if account.password and not compare_password(
|
|
|
|
|
password, account.password, account.password_salt
|
|
|
|
|
):
|
|
|
|
|
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
|
|
|
|
|
|
|
|
|
# may be raised
|
|
|
|
|
@ -244,11 +267,18 @@ class AccountService:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def create_account_in_tenant(
|
|
|
|
|
tenant: Tenant, email: str, name: str, interface_language: str, password: Optional[str] = None
|
|
|
|
|
tenant: Tenant,
|
|
|
|
|
email: str,
|
|
|
|
|
name: str,
|
|
|
|
|
interface_language: str,
|
|
|
|
|
password: Optional[str] = None,
|
|
|
|
|
) -> Account:
|
|
|
|
|
"""create account"""
|
|
|
|
|
account = AccountService.create_account(
|
|
|
|
|
email=email, name=name, interface_language=interface_language, password=password
|
|
|
|
|
email=email,
|
|
|
|
|
name=name,
|
|
|
|
|
interface_language=interface_language,
|
|
|
|
|
password=password,
|
|
|
|
|
)
|
|
|
|
|
TenantService.create_tenant_member(tenant, account, role="end_user")
|
|
|
|
|
return account
|
|
|
|
|
@ -259,7 +289,10 @@ class AccountService:
|
|
|
|
|
) -> Account:
|
|
|
|
|
"""create account"""
|
|
|
|
|
account = AccountService.create_account(
|
|
|
|
|
email=email, name=name, interface_language=interface_language, password=password
|
|
|
|
|
email=email,
|
|
|
|
|
name=name,
|
|
|
|
|
interface_language=interface_language,
|
|
|
|
|
password=password,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
TenantService.create_owner_tenant_if_not_exist(account=account)
|
|
|
|
|
@ -267,10 +300,14 @@ class AccountService:
|
|
|
|
|
return account
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]:
|
|
|
|
|
def generate_account_deletion_verification_code(
|
|
|
|
|
account: Account,
|
|
|
|
|
) -> tuple[str, str]:
|
|
|
|
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
|
|
|
|
token = TokenManager.generate_token(
|
|
|
|
|
account=account, token_type="account_deletion", additional_data={"code": code}
|
|
|
|
|
account=account,
|
|
|
|
|
token_type="account_deletion",
|
|
|
|
|
additional_data={"code": code},
|
|
|
|
|
)
|
|
|
|
|
return token, code
|
|
|
|
|
|
|
|
|
|
@ -278,7 +315,9 @@ class AccountService:
|
|
|
|
|
def send_account_deletion_verification_email(cls, account: Account, code: str):
|
|
|
|
|
email = account.email
|
|
|
|
|
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
|
|
|
|
|
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
|
|
|
|
|
from controllers.console.auth.error import (
|
|
|
|
|
EmailCodeAccountDeletionRateLimitExceededError,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
raise EmailCodeAccountDeletionRateLimitExceededError()
|
|
|
|
|
|
|
|
|
|
@ -307,9 +346,11 @@ class AccountService:
|
|
|
|
|
"""Link account integrate"""
|
|
|
|
|
try:
|
|
|
|
|
# Query whether there is an existing binding record for the same provider
|
|
|
|
|
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
|
|
|
|
|
account_id=account.id, provider=provider
|
|
|
|
|
).first()
|
|
|
|
|
account_integrate: Optional[AccountIntegrate] = (
|
|
|
|
|
AccountIntegrate.query.filter_by(
|
|
|
|
|
account_id=account.id, provider=provider
|
|
|
|
|
).first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if account_integrate:
|
|
|
|
|
# If it exists, update the record
|
|
|
|
|
@ -319,14 +360,19 @@ class AccountService:
|
|
|
|
|
else:
|
|
|
|
|
# If it does not exist, create a new record
|
|
|
|
|
account_integrate = AccountIntegrate(
|
|
|
|
|
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
|
|
|
|
|
account_id=account.id,
|
|
|
|
|
provider=provider,
|
|
|
|
|
open_id=open_id,
|
|
|
|
|
encrypted_token="",
|
|
|
|
|
)
|
|
|
|
|
db.session.add(account_integrate)
|
|
|
|
|
|
|
|
|
|
db.session.commit()
|
|
|
|
|
logging.info(f"Account {account.id} linked {provider} account {open_id}.")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
|
|
|
|
|
logging.exception(
|
|
|
|
|
f"Failed to link {provider} account {open_id} to Account {account.id}"
|
|
|
|
|
)
|
|
|
|
|
raise LinkAccountIntegrateError("Failed to link account.") from e
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@ -373,14 +419,20 @@ class AccountService:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def logout(*, account: Account) -> None:
|
|
|
|
|
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
|
|
|
|
|
refresh_token = redis_client.get(
|
|
|
|
|
AccountService._get_account_refresh_token_key(account.id)
|
|
|
|
|
)
|
|
|
|
|
if refresh_token:
|
|
|
|
|
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
|
|
|
|
AccountService._delete_refresh_token(
|
|
|
|
|
refresh_token.decode("utf-8"), account.id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def refresh_token(refresh_token: str) -> TokenPair:
|
|
|
|
|
# Verify the refresh token
|
|
|
|
|
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
|
|
|
|
|
account_id = redis_client.get(
|
|
|
|
|
AccountService._get_refresh_token_key(refresh_token)
|
|
|
|
|
)
|
|
|
|
|
if not account_id:
|
|
|
|
|
raise ValueError("Invalid refresh token")
|
|
|
|
|
|
|
|
|
|
@ -413,13 +465,18 @@ class AccountService:
|
|
|
|
|
raise ValueError("Email must be provided.")
|
|
|
|
|
|
|
|
|
|
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
|
|
|
|
|
from controllers.console.auth.error import PasswordResetRateLimitExceededError
|
|
|
|
|
from controllers.console.auth.error import (
|
|
|
|
|
PasswordResetRateLimitExceededError,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
raise PasswordResetRateLimitExceededError()
|
|
|
|
|
|
|
|
|
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
|
|
|
|
token = TokenManager.generate_token(
|
|
|
|
|
account=account, email=email, token_type="reset_password", additional_data={"code": code}
|
|
|
|
|
account=account,
|
|
|
|
|
email=email,
|
|
|
|
|
token_type="reset_password",
|
|
|
|
|
additional_data={"code": code},
|
|
|
|
|
)
|
|
|
|
|
send_reset_password_mail_task.delay(
|
|
|
|
|
language=language,
|
|
|
|
|
@ -439,19 +496,27 @@ class AccountService:
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def send_email_code_login_email(
|
|
|
|
|
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
|
|
|
|
|
cls,
|
|
|
|
|
account: Optional[Account] = None,
|
|
|
|
|
email: Optional[str] = None,
|
|
|
|
|
language: Optional[str] = "en-US",
|
|
|
|
|
):
|
|
|
|
|
email = account.email if account else email
|
|
|
|
|
if email is None:
|
|
|
|
|
raise ValueError("Email must be provided.")
|
|
|
|
|
if cls.email_code_login_rate_limiter.is_rate_limited(email):
|
|
|
|
|
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
|
|
|
|
|
from controllers.console.auth.error import (
|
|
|
|
|
EmailCodeLoginRateLimitExceededError,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
raise EmailCodeLoginRateLimitExceededError()
|
|
|
|
|
|
|
|
|
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
|
|
|
|
token = TokenManager.generate_token(
|
|
|
|
|
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
|
|
|
|
|
account=account,
|
|
|
|
|
email=email,
|
|
|
|
|
token_type="email_code_login",
|
|
|
|
|
additional_data={"code": code},
|
|
|
|
|
)
|
|
|
|
|
send_email_code_login_mail_task.delay(
|
|
|
|
|
language=language,
|
|
|
|
|
@ -541,7 +606,9 @@ class AccountService:
|
|
|
|
|
redis_client.setex(freeze_key, 60 * 60, 1)
|
|
|
|
|
return True
|
|
|
|
|
else:
|
|
|
|
|
redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
|
|
|
|
|
redis_client.setex(
|
|
|
|
|
hour_limit_key, 60 * 10, hour_limit_count + 1
|
|
|
|
|
) # first time limit 10 minutes
|
|
|
|
|
|
|
|
|
|
# add hour limit count
|
|
|
|
|
redis_client.incr(hour_limit_key)
|
|
|
|
|
@ -564,9 +631,13 @@ class TenantService:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_tenant_by_id(tenant_id: str) -> Tenant:
|
|
|
|
|
return Tenant.query.filter_by(id=tenant_id).first()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
|
|
|
|
|
def create_tenant(
|
|
|
|
|
name: str,
|
|
|
|
|
is_setup: Optional[bool] = False,
|
|
|
|
|
is_from_dashboard: Optional[bool] = False,
|
|
|
|
|
) -> Tenant:
|
|
|
|
|
"""Create tenant"""
|
|
|
|
|
if (
|
|
|
|
|
not FeatureService.get_system_features().is_allow_create_workspace
|
|
|
|
|
@ -591,38 +662,53 @@ class TenantService:
|
|
|
|
|
):
|
|
|
|
|
"""Check if user have a workspace or not"""
|
|
|
|
|
available_ta = (
|
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
|
|
|
|
TenantAccountJoin.query.filter_by(account_id=account.id)
|
|
|
|
|
.order_by(TenantAccountJoin.id.asc())
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if available_ta:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
"""Create owner tenant if not exist"""
|
|
|
|
|
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
|
|
|
|
|
if (
|
|
|
|
|
not FeatureService.get_system_features().is_allow_create_workspace
|
|
|
|
|
and not is_setup
|
|
|
|
|
):
|
|
|
|
|
raise WorkSpaceNotAllowedCreateError()
|
|
|
|
|
|
|
|
|
|
if name:
|
|
|
|
|
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
|
|
|
|
|
else:
|
|
|
|
|
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
|
|
|
|
|
tenant = TenantService.create_tenant(
|
|
|
|
|
name=f"{account.name}'s Workspace", is_setup=is_setup
|
|
|
|
|
)
|
|
|
|
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
|
|
|
|
account.current_tenant = tenant
|
|
|
|
|
db.session.commit()
|
|
|
|
|
tenant_was_created.send(tenant)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
|
|
|
|
|
def create_tenant_member(
|
|
|
|
|
tenant: Tenant, account: Account, role: str = "normal"
|
|
|
|
|
) -> TenantAccountJoin:
|
|
|
|
|
"""Create tenant member"""
|
|
|
|
|
if role == TenantAccountJoinRole.OWNER.value:
|
|
|
|
|
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
|
|
|
|
|
logging.error(f"Tenant {tenant.id} has already an owner.")
|
|
|
|
|
raise Exception("Tenant already has an owner.")
|
|
|
|
|
|
|
|
|
|
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
|
ta = (
|
|
|
|
|
db.session.query(TenantAccountJoin)
|
|
|
|
|
.filter_by(tenant_id=tenant.id, account_id=account.id)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if ta:
|
|
|
|
|
ta.role = role
|
|
|
|
|
else:
|
|
|
|
|
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
|
|
|
|
|
ta = TenantAccountJoin(
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id, role=role
|
|
|
|
|
)
|
|
|
|
|
db.session.add(ta)
|
|
|
|
|
|
|
|
|
|
db.session.commit()
|
|
|
|
|
@ -634,7 +720,10 @@ class TenantService:
|
|
|
|
|
return (
|
|
|
|
|
db.session.query(Tenant)
|
|
|
|
|
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
|
|
|
|
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
|
|
|
|
|
.filter(
|
|
|
|
|
TenantAccountJoin.account_id == account.id,
|
|
|
|
|
Tenant.status == TenantStatus.NORMAL,
|
|
|
|
|
)
|
|
|
|
|
.all()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -645,7 +734,9 @@ class TenantService:
|
|
|
|
|
if not tenant:
|
|
|
|
|
raise TenantNotFoundError("Tenant not found.")
|
|
|
|
|
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id
|
|
|
|
|
).first()
|
|
|
|
|
if ta:
|
|
|
|
|
tenant.role = ta.role
|
|
|
|
|
else:
|
|
|
|
|
@ -672,10 +763,13 @@ class TenantService:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not tenant_account_join:
|
|
|
|
|
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
|
|
|
|
raise AccountNotLinkTenantError(
|
|
|
|
|
"Tenant not found or account is not a member of the tenant."
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
TenantAccountJoin.query.filter(
|
|
|
|
|
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
|
|
|
|
|
TenantAccountJoin.account_id == account.id,
|
|
|
|
|
TenantAccountJoin.tenant_id != tenant_id,
|
|
|
|
|
).update({"current": False})
|
|
|
|
|
tenant_account_join.current = True
|
|
|
|
|
# Set the current tenant for the account
|
|
|
|
|
@ -730,18 +824,24 @@ class TenantService:
|
|
|
|
|
return (
|
|
|
|
|
db.session.query(TenantAccountJoin)
|
|
|
|
|
.filter(
|
|
|
|
|
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
|
|
|
|
|
TenantAccountJoin.tenant_id == tenant.id,
|
|
|
|
|
TenantAccountJoin.role.in_([role.value for role in roles]),
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
is not None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
|
|
|
|
|
def get_user_role(
|
|
|
|
|
account: Account, tenant: Tenant
|
|
|
|
|
) -> Optional[TenantAccountJoinRole]:
|
|
|
|
|
"""Get the role of the current account for a given tenant"""
|
|
|
|
|
join = (
|
|
|
|
|
db.session.query(TenantAccountJoin)
|
|
|
|
|
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
|
|
|
|
.filter(
|
|
|
|
|
TenantAccountJoin.tenant_id == tenant.id,
|
|
|
|
|
TenantAccountJoin.account_id == account.id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
return join.role if join else None
|
|
|
|
|
@ -752,7 +852,9 @@ class TenantService:
|
|
|
|
|
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
|
|
|
|
|
def check_member_permission(
|
|
|
|
|
tenant: Tenant, operator: Account, member: Account | None, action: str
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Check member permission"""
|
|
|
|
|
perms = {
|
|
|
|
|
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
|
|
|
|
@ -766,18 +868,26 @@ class TenantService:
|
|
|
|
|
if operator.id == member.id:
|
|
|
|
|
raise CannotOperateSelfError("Cannot operate self.")
|
|
|
|
|
|
|
|
|
|
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
|
|
|
|
|
ta_operator = TenantAccountJoin.query.filter_by(
|
|
|
|
|
tenant_id=tenant.id, account_id=operator.id
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
if not ta_operator or ta_operator.role not in perms[action]:
|
|
|
|
|
raise NoPermissionError(f"No permission to {action} member.")
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
|
|
|
|
def remove_member_from_tenant(
|
|
|
|
|
tenant: Tenant, account: Account, operator: Account
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Remove member from tenant"""
|
|
|
|
|
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
|
|
|
|
|
if operator.id == account.id and TenantService.check_member_permission(
|
|
|
|
|
tenant, operator, account, "remove"
|
|
|
|
|
):
|
|
|
|
|
raise CannotOperateSelfError("Cannot operate self.")
|
|
|
|
|
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id
|
|
|
|
|
).first()
|
|
|
|
|
if not ta:
|
|
|
|
|
raise MemberNotInTenantError("Member not in tenant.")
|
|
|
|
|
|
|
|
|
|
@ -785,18 +895,26 @@ class TenantService:
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
|
|
|
|
|
def update_member_role(
|
|
|
|
|
tenant: Tenant, member: Account, new_role: str, operator: Account
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Update member role"""
|
|
|
|
|
TenantService.check_member_permission(tenant, operator, member, "update")
|
|
|
|
|
|
|
|
|
|
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
|
|
|
|
|
target_member_join = TenantAccountJoin.query.filter_by(
|
|
|
|
|
tenant_id=tenant.id, account_id=member.id
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
if target_member_join.role == new_role:
|
|
|
|
|
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
|
|
|
|
|
raise RoleAlreadyAssignedError(
|
|
|
|
|
"The provided role is already assigned to the member."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if new_role == "owner":
|
|
|
|
|
# Find the current owner and change their role to 'admin'
|
|
|
|
|
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
|
|
|
|
|
current_owner_join = TenantAccountJoin.query.filter_by(
|
|
|
|
|
tenant_id=tenant.id, role="owner"
|
|
|
|
|
).first()
|
|
|
|
|
current_owner_join.role = "admin"
|
|
|
|
|
|
|
|
|
|
# Update the role of the target member
|
|
|
|
|
@ -806,7 +924,9 @@ class TenantService:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
|
|
|
|
|
"""Dissolve tenant"""
|
|
|
|
|
if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
|
|
|
|
|
if not TenantService.check_member_permission(
|
|
|
|
|
tenant, operator, operator, "remove"
|
|
|
|
|
):
|
|
|
|
|
raise NoPermissionError("No permission to dissolve tenant.")
|
|
|
|
|
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
|
|
|
|
db.session.delete(tenant)
|
|
|
|
|
@ -847,7 +967,9 @@ class RegisterService:
|
|
|
|
|
account.last_login_ip = ip_address
|
|
|
|
|
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
|
|
|
|
|
|
|
|
|
|
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
|
|
|
|
|
TenantService.create_owner_tenant_if_not_exist(
|
|
|
|
|
account=account, is_setup=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
|
|
|
|
|
db.session.add(dify_setup)
|
|
|
|
|
@ -891,7 +1013,10 @@ class RegisterService:
|
|
|
|
|
if open_id is not None and provider is not None:
|
|
|
|
|
AccountService.link_account_integrate(provider, open_id, account)
|
|
|
|
|
|
|
|
|
|
if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
|
|
|
|
|
if (
|
|
|
|
|
FeatureService.get_system_features().is_allow_create_workspace
|
|
|
|
|
and create_workspace_required
|
|
|
|
|
):
|
|
|
|
|
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
|
|
|
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
|
|
|
|
account.current_tenant = tenant
|
|
|
|
|
@ -913,7 +1038,12 @@ class RegisterService:
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def invite_new_member(
|
|
|
|
|
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
|
|
|
|
|
cls,
|
|
|
|
|
tenant: Tenant,
|
|
|
|
|
email: str,
|
|
|
|
|
language: str,
|
|
|
|
|
role: str = "normal",
|
|
|
|
|
inviter: Optional[Account] = None,
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Invite new member"""
|
|
|
|
|
account = Account.query.filter_by(email=email).first()
|
|
|
|
|
@ -924,14 +1054,20 @@ class RegisterService:
|
|
|
|
|
name = email.split("@")[0]
|
|
|
|
|
|
|
|
|
|
account = cls.register(
|
|
|
|
|
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
|
|
|
|
|
email=email,
|
|
|
|
|
name=name,
|
|
|
|
|
language=language,
|
|
|
|
|
status=AccountStatus.PENDING,
|
|
|
|
|
is_setup=True,
|
|
|
|
|
)
|
|
|
|
|
# Create new tenant member for invited tenant
|
|
|
|
|
TenantService.create_tenant_member(tenant, account, role)
|
|
|
|
|
TenantService.switch_tenant(account, tenant.id)
|
|
|
|
|
else:
|
|
|
|
|
TenantService.check_member_permission(tenant, inviter, account, "add")
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
|
|
|
|
ta = TenantAccountJoin.query.filter_by(
|
|
|
|
|
tenant_id=tenant.id, account_id=account.id
|
|
|
|
|
).first()
|
|
|
|
|
|
|
|
|
|
if not ta:
|
|
|
|
|
TenantService.create_tenant_member(tenant, account, role)
|
|
|
|
|
@ -962,7 +1098,11 @@ class RegisterService:
|
|
|
|
|
"workspace_id": tenant.id,
|
|
|
|
|
}
|
|
|
|
|
expiry_hours = dify_config.INVITE_EXPIRY_HOURS
|
|
|
|
|
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
|
|
|
|
|
redis_client.setex(
|
|
|
|
|
cls._get_invitation_token_key(token),
|
|
|
|
|
expiry_hours * 60 * 60,
|
|
|
|
|
json.dumps(invitation_data),
|
|
|
|
|
)
|
|
|
|
|
return token
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@ -974,7 +1114,9 @@ class RegisterService:
|
|
|
|
|
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
|
|
|
|
if workspace_id and email:
|
|
|
|
|
email_hash = sha256(email.encode()).hexdigest()
|
|
|
|
|
cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
|
|
|
|
|
cache_key = "member_invite_token:{}, {}:{}".format(
|
|
|
|
|
workspace_id, email_hash, token
|
|
|
|
|
)
|
|
|
|
|
redis_client.delete(cache_key)
|
|
|
|
|
else:
|
|
|
|
|
redis_client.delete(cls._get_invitation_token_key(token))
|
|
|
|
|
@ -989,7 +1131,9 @@ class RegisterService:
|
|
|
|
|
|
|
|
|
|
tenant = (
|
|
|
|
|
db.session.query(Tenant)
|
|
|
|
|
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
|
|
|
|
|
.filter(
|
|
|
|
|
Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal"
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -999,7 +1143,10 @@ class RegisterService:
|
|
|
|
|
tenant_account = (
|
|
|
|
|
db.session.query(Account, TenantAccountJoin.role)
|
|
|
|
|
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
|
|
|
|
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
|
|
|
|
|
.filter(
|
|
|
|
|
Account.email == invitation_data["email"],
|
|
|
|
|
TenantAccountJoin.tenant_id == tenant.id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|