diff --git a/api/controllers/admin/students/conversation.py b/api/controllers/admin/students/conversation.py index 28f3468205..807ea4bee6 100644 --- a/api/controllers/admin/students/conversation.py +++ b/api/controllers/admin/students/conversation.py @@ -8,7 +8,7 @@ from fields.conversation_fields import conversation_infinite_scroll_pagination_f from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful.inputs import int_range # type: ignore from libs.helper import uuid_value -from models.model import App, AppMode +from models.model import Account, App, AppMode from services.conversation_service import ConversationService from services.end_user_service import EndUserService from sqlalchemy.orm import Session # type: ignore @@ -18,7 +18,7 @@ from werkzeug.exceptions import NotFound class StudentConversation(Resource): @validate_admin_token_and_extract_info @marshal_with(conversation_infinite_scroll_pagination_fields) - def get(self, app_model: App, student_id: str): + def get(self, app_model: App, account: Account, student_id: str): """Get student's conversation history. --- tags: @@ -83,6 +83,10 @@ class StudentConversation(Resource): if not end_user: raise NotFound("Student not found") + # Ensure student belongs to admin's organization + if account.current_organization_id and end_user.organization_id != account.current_organization_id: + raise NotFound("Student not found in your organization") + app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() diff --git a/api/controllers/admin/students/message.py b/api/controllers/admin/students/message.py index 7e5be63042..89fa9cdac4 100644 --- a/api/controllers/admin/students/message.py +++ b/api/controllers/admin/students/message.py @@ -7,9 +7,8 @@ from fields.raws import FilesContainedField from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful.inputs import int_range # type: ignore from libs.helper import TimestampField, uuid_value -from models.model import App, AppMode, EndUser +from models.model import Account, App, AppMode from services.end_user_service import EndUserService -from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService from werkzeug.exceptions import NotFound @@ -73,7 +72,7 @@ class MessageListApi(Resource): @validate_admin_token_and_extract_info @marshal_with(message_infinite_scroll_pagination_fields) - def get(self, app_model: App, student_id: str): + def get(self, app_model: App, account: Account, student_id: str): """Get messages list. --- tags: @@ -136,6 +135,10 @@ class MessageListApi(Resource): if not end_user: raise NotFound("Student not found") + # Ensure student belongs to admin's organization + if account.current_organization_id and end_user.organization_id != account.current_organization_id: + raise NotFound("Student not found in your organization") + parser = reqparse.RequestParser() parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument("first_id", type=uuid_value, location="args") diff --git a/api/controllers/admin/students/students.py b/api/controllers/admin/students/students.py index b3467706e5..5fb0076c52 100644 --- a/api/controllers/admin/students/students.py +++ b/api/controllers/admin/students/students.py @@ -3,14 +3,14 @@ from controllers.admin.wraps import validate_admin_token_and_extract_info from fields.end_user_fields import end_users_infinite_scroll_pagination_fields from flask import Blueprint from flask_restful import Api, Resource, marshal_with # type: ignore -from models.model import App +from models.model import Account, App from services.end_user_service import EndUserService class StudentList(Resource): @validate_admin_token_and_extract_info @marshal_with(end_users_infinite_scroll_pagination_fields) - def get(self, app_model: App): + def get(self, app_model: App, account: Account): """Get all end_user list related with the app_model with filters with pagination. --- tags: @@ -88,6 +88,8 @@ class StudentList(Resource): type: string major: type: string + organization_id: + type: string 401: description: Invalid or missing API key 400: @@ -133,12 +135,18 @@ class StudentList(Resource): # Get students with pagination offset = (page - 1) * limit - return EndUserService.pagination_by_filters(app_model=app_model, filters=filters, offset=offset, limit=limit) + # Get the organization ID from the account + organization_id = account.current_organization_id + + # Use the organization ID for filtering students by organization + return EndUserService.pagination_by_filters( + app_model=app_model, filters=filters, offset=offset, limit=limit, organization_id=organization_id + ) class StudentAnalysis(Resource): @validate_admin_token_and_extract_info - def get(self, app_model: App): + def get(self, app_model: App, account: Account): """Get AI analysis and intervention suggestions. --- tags: @@ -181,7 +189,7 @@ class StudentAnalysis(Resource): class StudentStatus(Resource): @validate_admin_token_and_extract_info - def put(self, app_model: App): + def put(self, app_model: App, account: Account): """Update student follow-up status. --- tags: @@ -234,7 +242,7 @@ class StudentStatus(Resource): pass @validate_admin_token_and_extract_info - def get(self, app_model: App): + def get(self, app_model: App, account: Account): """Get student follow-up status history. --- tags: @@ -282,7 +290,7 @@ class StudentStatus(Resource): class StudentNote(Resource): @validate_admin_token_and_extract_info - def put(self, app_model: App): + def put(self, app_model: App, account: Account): """Update student follow-up note. --- tags: @@ -331,7 +339,7 @@ class StudentNote(Resource): pass @validate_admin_token_and_extract_info - def get(self, app_model: App): + def get(self, app_model: App, account: Account): """Get student follow-up note history. --- tags: diff --git a/api/controllers/admin/wraps.py b/api/controllers/admin/wraps.py index 670237e611..34bfbe8c95 100644 --- a/api/controllers/admin/wraps.py +++ b/api/controllers/admin/wraps.py @@ -1,23 +1,14 @@ from collections.abc import Callable -from datetime import UTC, datetime, timedelta -from enum import Enum from functools import wraps from typing import Optional from configs import dify_config from extensions.ext_database import db -from flask import current_app, request -from flask_login import user_logged_in # type: ignore -from flask_restful import Resource # type: ignore -from libs.login import _get_user +from flask import request from libs.passport import PassportService -from models.account import Account, AccountStatus, Tenant, TenantAccountJoinRole, TenantStatus -from models.model import ApiToken, App, EndUser -from pydantic import BaseModel # type: ignore +from models.account import AccountStatus, Tenant, TenantAccountJoinRole, TenantStatus +from models.model import App from services.account_service import AccountService -from services.feature_service import FeatureService -from sqlalchemy import select, update # type: ignore -from sqlalchemy.orm import Session # type: ignore from werkzeug.exceptions import Forbidden, Unauthorized @@ -75,7 +66,9 @@ def validate_admin_token_and_extract_info(view: Optional[Callable] = None): if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") + # Pass account and app_model to the view kwargs["app_model"] = app_model + kwargs["account"] = account return view_func(*args, **kwargs) diff --git a/api/services/end_user_service.py b/api/services/end_user_service.py index b36a064e75..bd90d67349 100644 --- a/api/services/end_user_service.py +++ b/api/services/end_user_service.py @@ -10,14 +10,18 @@ from sqlalchemy import and_, desc, func class EndUserService: @staticmethod - def pagination_by_filters(app_model: App, filters: Dict[str, Any], offset: int, limit: int) -> MultiPagePagination: + def pagination_by_filters( + app_model: App, filters: Dict[str, Any], offset: int, limit: int, organization_id: Optional[str] = None + ) -> MultiPagePagination: """ Get a list of end users with filtering and pagination Args: + app_model: The app model filters: Dictionary containing filter criteria offset: Number of records to skip limit: Maximum number of records to return + organization_id: Optional organization ID to filter users by Returns: Dictionary containing total count and list of end users @@ -63,6 +67,10 @@ class EndUserService: .filter(EndUser.external_user_id != None) ) + # Filter by organization if specified + if organization_id: + query = query.filter(EndUser.organization_id == organization_id) + # Apply filters filter_conditions = [] @@ -110,6 +118,7 @@ class EndUserService: 'topics': end_user.topics, 'summary': end_user.summary, 'major': end_user.major, + 'organization_id': end_user.organization_id, } users.append(end_user_dict)