add login check in service api auth

pull/21891/head
ytqh 1 year ago
parent 29df704818
commit 7102fe396c

@ -4,21 +4,22 @@ import flask_login # type: ignore
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from controllers.service_api_with_auth import api from controllers.service_api_with_auth import api
from controllers.service_api_with_auth.auth.error import (EmailCodeError, from controllers.service_api_with_auth.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError
InvalidEmailError, from controllers.service_api_with_auth.error import (
InvalidTokenError) AccountInFreezeError,
from controllers.service_api_with_auth.error import (AccountInFreezeError,
AccountNotFound,
EmailSendIpLimitError, EmailSendIpLimitError,
TenantNotFoundError) OrganizationMismatchError,
OrganizationNotFoundError,
TenantNotFoundError,
)
from extensions.ext_database import db
from flask import request from flask import request
from flask_restful import Resource, reqparse # type: ignore from flask_restful import Resource, reqparse # type: ignore
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from models.account import Account from models.account import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError from services.organization_service import OrganizationService
from services.feature_service import FeatureService
class LogoutApi(Resource): class LogoutApi(Resource):
@ -189,19 +190,38 @@ class EmailCodeLoginApi(Resource):
if tenant is None: if tenant is None:
raise TenantNotFoundError() raise TenantNotFoundError()
# Find organization based on email domain
organization = OrganizationService.find_organization_by_email_domain(user_email, tenant.id)
if organization is None:
raise OrganizationNotFoundError()
is_new_user = account is None
if account is None: if account is None:
try:
# Create new account
account = AccountService.create_account_in_tenant( account = AccountService.create_account_in_tenant(
tenant=tenant, tenant=tenant,
email=user_email, email=user_email,
name=user_email, name=user_email,
interface_language=languages[0], interface_language=languages[0],
) )
is_new_user = True
except AccountRegisterError as are: # Assign organization if found
raise AccountInFreezeError() if organization:
OrganizationService.assign_account_to_organization(account, organization.id)
else: else:
is_new_user = False
if account.organization_id is not None and account.organization_id != organization.id:
raise OrganizationMismatchError()
# Update organization if needed
if organization:
OrganizationService.assign_account_to_organization(account, organization.id)
# Ensure account is member of tenant
connected_tenant = TenantService.get_join_tenants(account) connected_tenant = TenantService.get_join_tenants(account)
if connected_tenant is None or tenant not in connected_tenant: if connected_tenant is None or tenant not in connected_tenant:
TenantService.create_tenant_member(tenant, account, role="end_user") TenantService.create_tenant_member(tenant, account, role="end_user")

@ -102,7 +102,20 @@ class AccountInFreezeError(BaseHTTPException):
"and is temporarily unavailable for new account registration." "and is temporarily unavailable for new account registration."
) )
class TenantNotFoundError(BaseHTTPException): class TenantNotFoundError(BaseHTTPException):
error_code = "tenant_not_found" error_code = "tenant_not_found"
description = "Tenant not found." description = "Tenant not found."
code = 400 code = 400
class OrganizationNotFoundError(BaseHTTPException):
error_code = "organization_not_found"
description = "Organization not found."
code = 400
class OrganizationMismatchError(BaseHTTPException):
error_code = "organization_mismatch"
description = "Organization mismatch."
code = 400

@ -13,8 +13,10 @@ from libs.login import _get_user
from libs.passport import PassportService from libs.passport import PassportService
from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountJoinRole, TenantStatus from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountJoinRole, TenantStatus
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
from models.organization import Organization
from pydantic import BaseModel # type: ignore from pydantic import BaseModel # type: ignore
from services.account_service import AccountService from services.account_service import AccountService
from services.end_user_service import EndUserService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from sqlalchemy import select, update # type: ignore from sqlalchemy import select, update # type: ignore
from sqlalchemy.orm import Session # type: ignore from sqlalchemy.orm import Session # type: ignore
@ -241,29 +243,8 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
if not user_id: if not user_id:
user_id = "DEFAULT-USER" user_id = "DEFAULT-USER"
end_user = ( # Use EndUserService to get or create end user with organization awareness
db.session.query(EndUser) return EndUserService.get_or_create_end_user(app_model, user_id)
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.external_user_id == user_id,
EndUser.type == "service_api_with_auth",
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api_with_auth",
session_id=user_id,
external_user_id=user_id,
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource): class DatasetApiResource(Resource):

@ -2,7 +2,9 @@ from typing import Any, Dict, Optional, Tuple
from extensions.ext_database import db from extensions.ext_database import db
from libs.infinite_scroll_pagination import MultiPagePagination from libs.infinite_scroll_pagination import MultiPagePagination
from models.account import Account
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message
from services.organization_service import OrganizationService
from sqlalchemy import and_, desc, func from sqlalchemy import and_, desc, func
@ -192,3 +194,110 @@ class EndUserService:
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
return False, str(e) return False, str(e)
@classmethod
def get_or_create_end_user(cls, app_model: App, user_id: str, user_type: str = "service_api_with_auth") -> EndUser:
"""
Get or create an end user with organization awareness
Args:
app_model: The app model
user_id: The external user ID (often an account ID)
user_type: The type of end user (default: service_api_with_auth)
Returns:
The end user
"""
if not user_id:
user_id = "DEFAULT-USER"
# Find existing end user
end_user = (
db.session.query(EndUser)
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.external_user_id == user_id,
EndUser.type == user_type,
)
.first()
)
# Get organization if the user has an account
organization_id = None
if user_id != "DEFAULT-USER":
account = db.session.query(Account).filter(Account.id == user_id).first()
if account:
organization = OrganizationService.get_organization_for_account_or_assign(account, app_model.tenant_id)
if organization:
organization_id = organization.id
if not end_user:
# Create new end user
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=user_type,
external_user_id=user_id,
session_id=user_id,
organization_id=organization_id,
)
db.session.add(end_user)
db.session.commit()
elif organization_id and end_user.organization_id != organization_id:
# Update organization if needed
OrganizationService.assign_end_user_to_organization(end_user, organization_id)
return end_user
@classmethod
def get_organization_for_end_user(cls, end_user: EndUser) -> Optional[dict]:
"""
Get organization info for an end user
Args:
end_user: The end user
Returns:
Organization info as dict or None
"""
if not end_user or not end_user.organization_id:
return None
organization = OrganizationService.get_organization_by_id(end_user.organization_id)
if organization:
return {
"id": organization.id,
"name": organization.name,
"code": organization.code,
"type": organization.type,
}
return None
@classmethod
def update_end_user(
cls, end_user: EndUser, name: Optional[str] = None, organization_id: Optional[str] = None
) -> EndUser:
"""
Update an end user's properties
Args:
end_user: The end user to update
name: New name (optional)
organization_id: New organization ID (optional)
Returns:
The updated end user
"""
if not end_user:
raise ValueError("End user cannot be None")
if name:
end_user.name = name
if organization_id and end_user.organization_id != organization_id:
OrganizationService.assign_end_user_to_organization(end_user, organization_id)
db.session.commit()
return end_user

@ -0,0 +1,196 @@
from typing import List, Optional, Union
from extensions.ext_database import db
from models.account import Account, Tenant
from models.model import EndUser
from models.organization import Organization, OrganizationMember, OrganizationRole
class OrganizationService:
"""Service for handling organization-related operations"""
@classmethod
def find_organization_by_email_domain(cls, email: str, tenant_id: str) -> Optional[Organization]:
"""
Find an organization that matches the email domain for a given tenant
Args:
email: The email to check
tenant_id: The tenant ID to search in
Returns:
Organization or None if no match found
"""
if not email or '@' not in email:
return None
# Get email domain
email_domain = email.split('@')[-1].lower()
# Get active organizations for this tenant
organizations = (
db.session.query(Organization)
.filter(Organization.tenant_id == tenant_id, Organization.status == 'active')
.all()
)
# Check each organization for matching email domain
for organization in organizations:
if organization.validate_email(email):
return organization
return None
@classmethod
def assign_account_to_organization(
cls, account: Account, organization_id: str, role: str = OrganizationRole.STUDENT
) -> bool:
"""
Assign an account to an organization and set it as the current organization
Args:
account: The account to assign
organization_id: The organization ID to assign to
role: The role to assign within the organization
Returns:
bool: True if successful, False otherwise
"""
if not account or not organization_id:
return False
# Check if organization exists
organization = db.session.query(Organization).filter(Organization.id == organization_id).first()
if not organization:
return False
# Update account's current organization
account.current_organization_id = organization_id
# Check if the account is already a member of this organization
existing_member = (
db.session.query(OrganizationMember)
.filter(OrganizationMember.organization_id == organization_id, OrganizationMember.account_id == account.id)
.first()
)
# If not a member, add them
if not existing_member:
member = OrganizationMember(
organization_id=organization_id,
account_id=account.id,
role=role,
is_default=True,
created_by=account.id,
)
db.session.add(member)
db.session.commit()
return True
@classmethod
def assign_end_user_to_organization(cls, end_user: EndUser, organization_id: str) -> bool:
"""
Assign an end user to an organization
Args:
end_user: The end user to assign
organization_id: The organization ID to assign to
Returns:
bool: True if successful, False otherwise
"""
if not end_user or not organization_id:
return False
# Check if organization exists
organization = db.session.query(Organization).filter(Organization.id == organization_id).first()
if not organization:
return False
# Update end user's organization
end_user.organization_id = organization_id
db.session.commit()
return True
@classmethod
def get_organization_for_account_or_assign(cls, account: Account, tenant_id: str) -> Optional[Organization]:
"""
Get the current organization for an account, or find and assign one based on email domain
Args:
account: The account to check
tenant_id: The tenant ID to search in
Returns:
Organization or None if no match found
"""
if not account:
return None
# If account already has an organization, return it
if account.current_organization_id:
return db.session.query(Organization).filter(Organization.id == account.current_organization_id).first()
# Otherwise, find an organization based on email domain
if account.email:
organization = cls.find_organization_by_email_domain(account.email, tenant_id)
if organization:
# Assign the account to this organization
cls.assign_account_to_organization(account, organization.id)
return organization
return None
@classmethod
def get_organization_for_end_user(cls, end_user: EndUser, tenant_id: str) -> Optional[Organization]:
"""
Get the organization for an end user, checking external account if needed
Args:
end_user: The end user to check
tenant_id: The tenant ID to search in
Returns:
Organization or None if no match found
"""
if not end_user:
return None
# If end user already has an organization, return it
if end_user.organization_id:
return db.session.query(Organization).filter(Organization.id == end_user.organization_id).first()
# If the end user has an external user ID that's an account, check that
if end_user.external_user_id and end_user.type == "service_api_with_auth":
account = db.session.query(Account).filter(Account.id == end_user.external_user_id).first()
if account:
organization = cls.get_organization_for_account_or_assign(account, tenant_id)
if organization:
# Assign the end user to this organization
cls.assign_end_user_to_organization(end_user, organization.id)
return organization
return None
@classmethod
def get_available_organizations_for_tenant(cls, tenant_id: str) -> List[Organization]:
"""
Get all active organizations for a tenant
Args:
tenant_id: The tenant ID to search in
Returns:
List of organizations
"""
return (
db.session.query(Organization)
.filter(Organization.tenant_id == tenant_id, Organization.status == 'active')
.all()
)
@classmethod
def get_organization_by_id(cls, organization_id: str) -> Optional[Organization]:
"""Get an organization by ID"""
return db.session.query(Organization).filter(Organization.id == organization_id).first()
Loading…
Cancel
Save