From c7e326aec809255523907e800883f18a2d9eb981 Mon Sep 17 00:00:00 2001 From: ytqh Date: Sat, 12 Apr 2025 18:59:31 +0800 Subject: [PATCH] feat: do not restrict health status --- api/controllers/admin/students/students.py | 2 -- api/models/model.py | 24 +++++++++++++++++++++- api/services/stats_service.py | 7 ++++--- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/api/controllers/admin/students/students.py b/api/controllers/admin/students/students.py index aad96e5e59..ed98a2b00d 100644 --- a/api/controllers/admin/students/students.py +++ b/api/controllers/admin/students/students.py @@ -122,8 +122,6 @@ class StudentList(Resource): # Build query filters filters = {} if health_status: - if health_status not in ['normal', 'potential', 'critical']: - return {"error": "Invalid health_status"}, 400 filters['health_status'] = health_status if begin_date: diff --git a/api/models/model.py b/api/models/model.py index 70c7afcdb5..34ac2c3f01 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1343,6 +1343,28 @@ class OperationLog(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) +class HealthStatus(StrEnum): + NORMAL = "normal" # 一般 + POTENTIAL = "potential" # 有轻微问题 + CRITICAL = "critical" # 有严重问题 + GOOD = "good" # 良好 + LACKDATA = "lackdata" # 数据不足 + + @staticmethod + def value_of(value: str) -> "HealthStatus": + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in HealthStatus: + if mode.value == value: + return mode + + raise ValueError(f"invalid health status value: {value}") + + class EndUser(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "end_users" __table_args__ = ( @@ -1363,7 +1385,7 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined] session_id: Mapped[str] = mapped_column() gender = db.Column(db.Integer, nullable=False, server_default=db.text("0")) # 0: unknown, 1: male, 2: female profile_updated_at = db.Column(db.DateTime, nullable=True) # To record when profile was last updated - health_status = db.Column(db.String(255), nullable=True) # Only accept for "normal", "potential", "critical" + health_status = db.Column(db.String(255), nullable=True) extra_profile = db.Column( db.JSON, nullable=True ) # JSON format, e.g. { "major":"engineer", "topics":["math", "physics"], "summary": "This is a summary of the user's profile", "memory": "This is the user's memory"} diff --git a/api/services/stats_service.py b/api/services/stats_service.py index af5f118a8f..8ddfeac624 100644 --- a/api/services/stats_service.py +++ b/api/services/stats_service.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional from extensions.ext_database import db from models import Conversation, EndUser, Message +from models.model import HealthStatus from sqlalchemy import and_, distinct, func @@ -22,7 +23,7 @@ class StatsService: Dictionary containing high risk user count and changes """ # Build query with filters - high_risk_query = db.session.query(EndUser).filter(EndUser.health_status == 'critical') + high_risk_query = db.session.query(EndUser).filter(EndUser.health_status == HealthStatus.CRITICAL.value) total_query = db.session.query(EndUser) # Apply app_id filter if provided @@ -41,7 +42,7 @@ class StatsService: # Get yesterday's count yesterday = datetime.now() - timedelta(days=1) yesterday_query = db.session.query(EndUser).filter( - EndUser.health_status == 'critical', EndUser.updated_at <= yesterday + EndUser.health_status == HealthStatus.CRITICAL.value, EndUser.updated_at <= yesterday ) # Apply app_id filter if provided @@ -57,7 +58,7 @@ class StatsService: # Get last week's count last_week = datetime.now() - timedelta(days=7) last_week_query = db.session.query(EndUser).filter( - EndUser.health_status == 'critical', EndUser.updated_at <= last_week + EndUser.health_status == HealthStatus.CRITICAL.value, EndUser.updated_at <= last_week ) # Apply app_id filter if provided