Enhance student management by adding organization filtering. Updated API endpoints to include account information and ensure students belong to the admin's organization. Modified EndUserService to support organization-based pagination.

pull/21891/head
ytqh 1 year ago
parent a947049d7f
commit 1d72b70d82

@ -8,7 +8,7 @@ from fields.conversation_fields import conversation_infinite_scroll_pagination_f
from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import App, AppMode from models.model import Account, App, AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.end_user_service import EndUserService from services.end_user_service import EndUserService
from sqlalchemy.orm import Session # type: ignore from sqlalchemy.orm import Session # type: ignore
@ -18,7 +18,7 @@ from werkzeug.exceptions import NotFound
class StudentConversation(Resource): class StudentConversation(Resource):
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, student_id: str): def get(self, app_model: App, account: Account, student_id: str):
"""Get student's conversation history. """Get student's conversation history.
--- ---
tags: tags:
@ -83,6 +83,10 @@ class StudentConversation(Resource):
if not end_user: if not end_user:
raise NotFound("Student not found") raise NotFound("Student not found")
# Ensure student belongs to admin's organization
if account.current_organization_id and end_user.organization_id != account.current_organization_id:
raise NotFound("Student not found in your organization")
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()

@ -7,9 +7,8 @@ from fields.raws import FilesContainedField
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser from models.model import Account, App, AppMode
from services.end_user_service import EndUserService from services.end_user_service import EndUserService
from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService from services.message_service import MessageService
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -73,7 +72,7 @@ class MessageListApi(Resource):
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, student_id: str): def get(self, app_model: App, account: Account, student_id: str):
"""Get messages list. """Get messages list.
--- ---
tags: tags:
@ -136,6 +135,10 @@ class MessageListApi(Resource):
if not end_user: if not end_user:
raise NotFound("Student not found") raise NotFound("Student not found")
# Ensure student belongs to admin's organization
if account.current_organization_id and end_user.organization_id != account.current_organization_id:
raise NotFound("Student not found in your organization")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args") parser.add_argument("first_id", type=uuid_value, location="args")

@ -3,14 +3,14 @@ from controllers.admin.wraps import validate_admin_token_and_extract_info
from fields.end_user_fields import end_users_infinite_scroll_pagination_fields from fields.end_user_fields import end_users_infinite_scroll_pagination_fields
from flask import Blueprint from flask import Blueprint
from flask_restful import Api, Resource, marshal_with # type: ignore from flask_restful import Api, Resource, marshal_with # type: ignore
from models.model import App from models.model import Account, App
from services.end_user_service import EndUserService from services.end_user_service import EndUserService
class StudentList(Resource): class StudentList(Resource):
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
@marshal_with(end_users_infinite_scroll_pagination_fields) @marshal_with(end_users_infinite_scroll_pagination_fields)
def get(self, app_model: App): def get(self, app_model: App, account: Account):
"""Get all end_user list related with the app_model with filters with pagination. """Get all end_user list related with the app_model with filters with pagination.
--- ---
tags: tags:
@ -88,6 +88,8 @@ class StudentList(Resource):
type: string type: string
major: major:
type: string type: string
organization_id:
type: string
401: 401:
description: Invalid or missing API key description: Invalid or missing API key
400: 400:
@ -133,12 +135,18 @@ class StudentList(Resource):
# Get students with pagination # Get students with pagination
offset = (page - 1) * limit offset = (page - 1) * limit
return EndUserService.pagination_by_filters(app_model=app_model, filters=filters, offset=offset, limit=limit) # Get the organization ID from the account
organization_id = account.current_organization_id
# Use the organization ID for filtering students by organization
return EndUserService.pagination_by_filters(
app_model=app_model, filters=filters, offset=offset, limit=limit, organization_id=organization_id
)
class StudentAnalysis(Resource): class StudentAnalysis(Resource):
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
def get(self, app_model: App): def get(self, app_model: App, account: Account):
"""Get AI analysis and intervention suggestions. """Get AI analysis and intervention suggestions.
--- ---
tags: tags:
@ -181,7 +189,7 @@ class StudentAnalysis(Resource):
class StudentStatus(Resource): class StudentStatus(Resource):
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
def put(self, app_model: App): def put(self, app_model: App, account: Account):
"""Update student follow-up status. """Update student follow-up status.
--- ---
tags: tags:
@ -234,7 +242,7 @@ class StudentStatus(Resource):
pass pass
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
def get(self, app_model: App): def get(self, app_model: App, account: Account):
"""Get student follow-up status history. """Get student follow-up status history.
--- ---
tags: tags:
@ -282,7 +290,7 @@ class StudentStatus(Resource):
class StudentNote(Resource): class StudentNote(Resource):
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
def put(self, app_model: App): def put(self, app_model: App, account: Account):
"""Update student follow-up note. """Update student follow-up note.
--- ---
tags: tags:
@ -331,7 +339,7 @@ class StudentNote(Resource):
pass pass
@validate_admin_token_and_extract_info @validate_admin_token_and_extract_info
def get(self, app_model: App): def get(self, app_model: App, account: Account):
"""Get student follow-up note history. """Get student follow-up note history.
--- ---
tags: tags:

@ -1,23 +1,14 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
from configs import dify_config from configs import dify_config
from extensions.ext_database import db from extensions.ext_database import db
from flask import current_app, request from flask import request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from libs.login import _get_user
from libs.passport import PassportService from libs.passport import PassportService
from models.account import Account, AccountStatus, Tenant, TenantAccountJoinRole, TenantStatus from models.account import AccountStatus, Tenant, TenantAccountJoinRole, TenantStatus
from models.model import ApiToken, App, EndUser from models.model import App
from pydantic import BaseModel # type: ignore
from services.account_service import AccountService from services.account_service import AccountService
from services.feature_service import FeatureService
from sqlalchemy import select, update # type: ignore
from sqlalchemy.orm import Session # type: ignore
from werkzeug.exceptions import Forbidden, Unauthorized from werkzeug.exceptions import Forbidden, Unauthorized
@ -75,7 +66,9 @@ def validate_admin_token_and_extract_info(view: Optional[Callable] = None):
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.") raise Forbidden("The workspace's status is archived.")
# Pass account and app_model to the view
kwargs["app_model"] = app_model kwargs["app_model"] = app_model
kwargs["account"] = account
return view_func(*args, **kwargs) return view_func(*args, **kwargs)

@ -10,14 +10,18 @@ from sqlalchemy import and_, desc, func
class EndUserService: class EndUserService:
@staticmethod @staticmethod
def pagination_by_filters(app_model: App, filters: Dict[str, Any], offset: int, limit: int) -> MultiPagePagination: def pagination_by_filters(
app_model: App, filters: Dict[str, Any], offset: int, limit: int, organization_id: Optional[str] = None
) -> MultiPagePagination:
""" """
Get a list of end users with filtering and pagination Get a list of end users with filtering and pagination
Args: Args:
app_model: The app model
filters: Dictionary containing filter criteria filters: Dictionary containing filter criteria
offset: Number of records to skip offset: Number of records to skip
limit: Maximum number of records to return limit: Maximum number of records to return
organization_id: Optional organization ID to filter users by
Returns: Returns:
Dictionary containing total count and list of end users Dictionary containing total count and list of end users
@ -63,6 +67,10 @@ class EndUserService:
.filter(EndUser.external_user_id != None) .filter(EndUser.external_user_id != None)
) )
# Filter by organization if specified
if organization_id:
query = query.filter(EndUser.organization_id == organization_id)
# Apply filters # Apply filters
filter_conditions = [] filter_conditions = []
@ -110,6 +118,7 @@ class EndUserService:
'topics': end_user.topics, 'topics': end_user.topics,
'summary': end_user.summary, 'summary': end_user.summary,
'major': end_user.major, 'major': end_user.major,
'organization_id': end_user.organization_id,
} }
users.append(end_user_dict) users.append(end_user_dict)

Loading…
Cancel
Save