diff --git a/api/controllers/service_api_with_auth/auth/login.py b/api/controllers/service_api_with_auth/auth/login.py index b780f2085c..6fa4a6d85f 100644 --- a/api/controllers/service_api_with_auth/auth/login.py +++ b/api/controllers/service_api_with_auth/auth/login.py @@ -4,21 +4,22 @@ import flask_login # type: ignore from configs import dify_config from constants.languages import languages from controllers.service_api_with_auth import api -from controllers.service_api_with_auth.auth.error import (EmailCodeError, - InvalidEmailError, - InvalidTokenError) -from controllers.service_api_with_auth.error import (AccountInFreezeError, - AccountNotFound, - EmailSendIpLimitError, - TenantNotFoundError) +from controllers.service_api_with_auth.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError +from controllers.service_api_with_auth.error import ( + AccountInFreezeError, + EmailSendIpLimitError, + OrganizationMismatchError, + OrganizationNotFoundError, + TenantNotFoundError, +) +from extensions.ext_database import db from flask import request from flask_restful import Resource, reqparse # type: ignore from libs.helper import email, extract_remote_ip from models.account import Account from services.account_service import AccountService, TenantService from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError -from services.feature_service import FeatureService +from services.organization_service import OrganizationService class LogoutApi(Resource): @@ -189,19 +190,38 @@ class EmailCodeLoginApi(Resource): if tenant is None: raise TenantNotFoundError() + # Find organization based on email domain + organization = OrganizationService.find_organization_by_email_domain(user_email, tenant.id) + + if organization is None: + raise OrganizationNotFoundError() + + is_new_user = account is None + if account is None: - try: - account = AccountService.create_account_in_tenant( - tenant=tenant, - email=user_email, - name=user_email, - interface_language=languages[0], - ) - is_new_user = True - except AccountRegisterError as are: - raise AccountInFreezeError() + + # Create new account + account = AccountService.create_account_in_tenant( + tenant=tenant, + email=user_email, + name=user_email, + interface_language=languages[0], + ) + + # Assign organization if found + if organization: + OrganizationService.assign_account_to_organization(account, organization.id) + else: - is_new_user = False + + if account.organization_id is not None and account.organization_id != organization.id: + raise OrganizationMismatchError() + + # Update organization if needed + if organization: + OrganizationService.assign_account_to_organization(account, organization.id) + + # Ensure account is member of tenant connected_tenant = TenantService.get_join_tenants(account) if connected_tenant is None or tenant not in connected_tenant: TenantService.create_tenant_member(tenant, account, role="end_user") diff --git a/api/controllers/service_api_with_auth/error.py b/api/controllers/service_api_with_auth/error.py index c17f3cc0c9..4b6603860b 100644 --- a/api/controllers/service_api_with_auth/error.py +++ b/api/controllers/service_api_with_auth/error.py @@ -102,7 +102,20 @@ class AccountInFreezeError(BaseHTTPException): "and is temporarily unavailable for new account registration." ) + class TenantNotFoundError(BaseHTTPException): error_code = "tenant_not_found" description = "Tenant not found." code = 400 + + +class OrganizationNotFoundError(BaseHTTPException): + error_code = "organization_not_found" + description = "Organization not found." + code = 400 + + +class OrganizationMismatchError(BaseHTTPException): + error_code = "organization_mismatch" + description = "Organization mismatch." + code = 400 diff --git a/api/controllers/service_api_with_auth/wraps.py b/api/controllers/service_api_with_auth/wraps.py index 133c1f0d0c..264bdbc7f5 100644 --- a/api/controllers/service_api_with_auth/wraps.py +++ b/api/controllers/service_api_with_auth/wraps.py @@ -13,8 +13,10 @@ from libs.login import _get_user from libs.passport import PassportService from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountJoinRole, TenantStatus from models.model import ApiToken, App, EndUser +from models.organization import Organization from pydantic import BaseModel # type: ignore from services.account_service import AccountService +from services.end_user_service import EndUserService from services.feature_service import FeatureService from sqlalchemy import select, update # type: ignore from sqlalchemy.orm import Session # type: ignore @@ -241,29 +243,8 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] if not user_id: user_id = "DEFAULT-USER" - end_user = ( - db.session.query(EndUser) - .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.app_id == app_model.id, - EndUser.external_user_id == user_id, - EndUser.type == "service_api_with_auth", - ) - .first() - ) - - if end_user is None: - end_user = EndUser( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - type="service_api_with_auth", - session_id=user_id, - external_user_id=user_id, - ) - db.session.add(end_user) - db.session.commit() - - return end_user + # Use EndUserService to get or create end user with organization awareness + return EndUserService.get_or_create_end_user(app_model, user_id) class DatasetApiResource(Resource): diff --git a/api/services/end_user_service.py b/api/services/end_user_service.py index 063916febb..b36a064e75 100644 --- a/api/services/end_user_service.py +++ b/api/services/end_user_service.py @@ -2,7 +2,9 @@ from typing import Any, Dict, Optional, Tuple from extensions.ext_database import db from libs.infinite_scroll_pagination import MultiPagePagination +from models.account import Account from models.model import App, Conversation, EndUser, Message +from services.organization_service import OrganizationService from sqlalchemy import and_, desc, func @@ -192,3 +194,110 @@ class EndUserService: except Exception as e: db.session.rollback() return False, str(e) + + @classmethod + def get_or_create_end_user(cls, app_model: App, user_id: str, user_type: str = "service_api_with_auth") -> EndUser: + """ + Get or create an end user with organization awareness + + Args: + app_model: The app model + user_id: The external user ID (often an account ID) + user_type: The type of end user (default: service_api_with_auth) + + Returns: + The end user + """ + if not user_id: + user_id = "DEFAULT-USER" + + # Find existing end user + end_user = ( + db.session.query(EndUser) + .filter( + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.external_user_id == user_id, + EndUser.type == user_type, + ) + .first() + ) + + # Get organization if the user has an account + organization_id = None + if user_id != "DEFAULT-USER": + account = db.session.query(Account).filter(Account.id == user_id).first() + if account: + organization = OrganizationService.get_organization_for_account_or_assign(account, app_model.tenant_id) + if organization: + organization_id = organization.id + + if not end_user: + # Create new end user + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=user_type, + external_user_id=user_id, + session_id=user_id, + organization_id=organization_id, + ) + db.session.add(end_user) + db.session.commit() + elif organization_id and end_user.organization_id != organization_id: + # Update organization if needed + OrganizationService.assign_end_user_to_organization(end_user, organization_id) + + return end_user + + @classmethod + def get_organization_for_end_user(cls, end_user: EndUser) -> Optional[dict]: + """ + Get organization info for an end user + + Args: + end_user: The end user + + Returns: + Organization info as dict or None + """ + if not end_user or not end_user.organization_id: + return None + + organization = OrganizationService.get_organization_by_id(end_user.organization_id) + if organization: + return { + "id": organization.id, + "name": organization.name, + "code": organization.code, + "type": organization.type, + } + + return None + + @classmethod + def update_end_user( + cls, end_user: EndUser, name: Optional[str] = None, organization_id: Optional[str] = None + ) -> EndUser: + """ + Update an end user's properties + + Args: + end_user: The end user to update + name: New name (optional) + organization_id: New organization ID (optional) + + Returns: + The updated end user + """ + if not end_user: + raise ValueError("End user cannot be None") + + if name: + end_user.name = name + + if organization_id and end_user.organization_id != organization_id: + OrganizationService.assign_end_user_to_organization(end_user, organization_id) + + db.session.commit() + return end_user diff --git a/api/services/organization_service.py b/api/services/organization_service.py new file mode 100644 index 0000000000..e611cdc0c5 --- /dev/null +++ b/api/services/organization_service.py @@ -0,0 +1,196 @@ +from typing import List, Optional, Union + +from extensions.ext_database import db +from models.account import Account, Tenant +from models.model import EndUser +from models.organization import Organization, OrganizationMember, OrganizationRole + + +class OrganizationService: + """Service for handling organization-related operations""" + + @classmethod + def find_organization_by_email_domain(cls, email: str, tenant_id: str) -> Optional[Organization]: + """ + Find an organization that matches the email domain for a given tenant + + Args: + email: The email to check + tenant_id: The tenant ID to search in + + Returns: + Organization or None if no match found + """ + if not email or '@' not in email: + return None + + # Get email domain + email_domain = email.split('@')[-1].lower() + + # Get active organizations for this tenant + organizations = ( + db.session.query(Organization) + .filter(Organization.tenant_id == tenant_id, Organization.status == 'active') + .all() + ) + + # Check each organization for matching email domain + for organization in organizations: + if organization.validate_email(email): + return organization + + return None + + @classmethod + def assign_account_to_organization( + cls, account: Account, organization_id: str, role: str = OrganizationRole.STUDENT + ) -> bool: + """ + Assign an account to an organization and set it as the current organization + + Args: + account: The account to assign + organization_id: The organization ID to assign to + role: The role to assign within the organization + + Returns: + bool: True if successful, False otherwise + """ + if not account or not organization_id: + return False + + # Check if organization exists + organization = db.session.query(Organization).filter(Organization.id == organization_id).first() + if not organization: + return False + + # Update account's current organization + account.current_organization_id = organization_id + + # Check if the account is already a member of this organization + existing_member = ( + db.session.query(OrganizationMember) + .filter(OrganizationMember.organization_id == organization_id, OrganizationMember.account_id == account.id) + .first() + ) + + # If not a member, add them + if not existing_member: + member = OrganizationMember( + organization_id=organization_id, + account_id=account.id, + role=role, + is_default=True, + created_by=account.id, + ) + db.session.add(member) + + db.session.commit() + return True + + @classmethod + def assign_end_user_to_organization(cls, end_user: EndUser, organization_id: str) -> bool: + """ + Assign an end user to an organization + + Args: + end_user: The end user to assign + organization_id: The organization ID to assign to + + Returns: + bool: True if successful, False otherwise + """ + if not end_user or not organization_id: + return False + + # Check if organization exists + organization = db.session.query(Organization).filter(Organization.id == organization_id).first() + if not organization: + return False + + # Update end user's organization + end_user.organization_id = organization_id + db.session.commit() + return True + + @classmethod + def get_organization_for_account_or_assign(cls, account: Account, tenant_id: str) -> Optional[Organization]: + """ + Get the current organization for an account, or find and assign one based on email domain + + Args: + account: The account to check + tenant_id: The tenant ID to search in + + Returns: + Organization or None if no match found + """ + if not account: + return None + + # If account already has an organization, return it + if account.current_organization_id: + return db.session.query(Organization).filter(Organization.id == account.current_organization_id).first() + + # Otherwise, find an organization based on email domain + if account.email: + organization = cls.find_organization_by_email_domain(account.email, tenant_id) + if organization: + # Assign the account to this organization + cls.assign_account_to_organization(account, organization.id) + return organization + + return None + + @classmethod + def get_organization_for_end_user(cls, end_user: EndUser, tenant_id: str) -> Optional[Organization]: + """ + Get the organization for an end user, checking external account if needed + + Args: + end_user: The end user to check + tenant_id: The tenant ID to search in + + Returns: + Organization or None if no match found + """ + if not end_user: + return None + + # If end user already has an organization, return it + if end_user.organization_id: + return db.session.query(Organization).filter(Organization.id == end_user.organization_id).first() + + # If the end user has an external user ID that's an account, check that + if end_user.external_user_id and end_user.type == "service_api_with_auth": + account = db.session.query(Account).filter(Account.id == end_user.external_user_id).first() + if account: + organization = cls.get_organization_for_account_or_assign(account, tenant_id) + if organization: + # Assign the end user to this organization + cls.assign_end_user_to_organization(end_user, organization.id) + return organization + + return None + + @classmethod + def get_available_organizations_for_tenant(cls, tenant_id: str) -> List[Organization]: + """ + Get all active organizations for a tenant + + Args: + tenant_id: The tenant ID to search in + + Returns: + List of organizations + """ + return ( + db.session.query(Organization) + .filter(Organization.tenant_id == tenant_id, Organization.status == 'active') + .all() + ) + + @classmethod + def get_organization_by_id(cls, organization_id: str) -> Optional[Organization]: + """Get an organization by ID""" + return db.session.query(Organization).filter(Organization.id == organization_id).first()