From 12fcc7a3ac7d68695dcb36765dfcd95d24406c1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=9F=8F=E8=B1=AA?= <> Date: Tue, 13 Jan 2026 17:32:44 +0800 Subject: [PATCH] feat: rename table --- crud/entity/models.py | 4 ++-- crud/service/scene_service.py | 32 ++++++++++++++++---------------- tools/check_labels.py | 4 ++-- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/crud/entity/models.py b/crud/entity/models.py index ae6cbd7..3899511 100644 --- a/crud/entity/models.py +++ b/crud/entity/models.py @@ -10,7 +10,7 @@ from crud.config.db import engine def get_uuid(): return uuid.uuid1().hex -class Scene(SQLModel,table=True): +class PointScene(SQLModel, table=True): id: str = Field(default_factory=get_uuid, max_length=32, primary_key=True) scene_name: str = Field(max_length=256, description="事件名称", unique=True) calib_json: List[dict] = Field(sa_type=JSON, nullable=True, description="calib是从点云到图像的校准矩阵。它是可选的,但如果提供,该框将投影在图像上以帮助注释。") @@ -22,7 +22,7 @@ class Scene(SQLModel,table=True): desc: List[dict] = Field(sa_type=JSON, nullable=True, description="calib是从点云到图像的校准矩阵。它是可选的,但如果提供,该框将投影在图像上以帮助注释。") -class SceneWorldItem(SQLModel,table=True): +class PointSceneWorldItem(SQLModel, table=True): __tablename__ = "scene_world_item" id: str = Field(default_factory=get_uuid, max_length=32, primary_key=True) scene_id: str = Field(max_length=256, description="事件id") diff --git a/crud/service/scene_service.py b/crud/service/scene_service.py index 4a94050..a4379c4 100644 --- a/crud/service/scene_service.py +++ b/crud/service/scene_service.py @@ -3,7 +3,7 @@ from typing import List, Dict, Any from sqlmodel import select, Session from crud.config.db import engine -from crud.entity.models import SceneWorldItem, Scene +from crud.entity.models import PointSceneWorldItem, PointScene from crud.entity.scene_dto import SaveWorldItem @@ -13,12 +13,12 @@ class SceneService: result_list = [] with Session(engine) as session: - query_stmt = select(Scene) + query_stmt = select(PointScene) if s is not None: - query_stmt = query_stmt.where(Scene.scene_name == s) + query_stmt = query_stmt.where(PointScene.scene_name == s) scene_all = session.exec(query_stmt) for scene in scene_all: - query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene.scene_name) + query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene.scene_name) query_result_item = session.exec(query_stmt) scene_worlds = query_result_item.all() item = { @@ -47,8 +47,8 @@ class SceneService: frame = item.frame ann = item.annotation with Session(engine) as session: - query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, - SceneWorldItem.frame == frame) + query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, + PointSceneWorldItem.frame == frame) result_item = session.exec(query_stmt).first() result_item.label_json = ann session.add(result_item) @@ -57,8 +57,8 @@ class SceneService: @classmethod def update_label_json(cls, scene, frame,ann): with Session(engine) as session: - query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, - SceneWorldItem.frame == frame) + query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, + PointSceneWorldItem.frame == frame) result_item = session.exec(query_stmt).first() result_item.label_json = ann session.add(result_item) @@ -69,43 +69,43 @@ class SceneService: def get_frame_ids(cls, scene: str): with Session(engine) as session: - exec_result = session.exec(select(SceneWorldItem.id).where(SceneWorldItem.scene_name == scene)) + exec_result = session.exec(select(PointSceneWorldItem.id).where(PointSceneWorldItem.scene_name == scene)) ids = exec_result.all() return ids @classmethod def get_scene_items(cls, scene: str): with Session(engine) as session: - items = session.exec(select(SceneWorldItem).where(SceneWorldItem.scene_name == scene)).all() + items = session.exec(select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene)).all() return items @classmethod def get_label_json(cls, scene, frame): with Session(engine) as session: - query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, SceneWorldItem.frame == frame) + query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, PointSceneWorldItem.frame == frame) result_item = session.exec(query_stmt).first() return result_item.label_json if result_item else [] @classmethod def get_ego_pose_json(cls, scene, frame): with Session(engine) as session: - query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, SceneWorldItem.frame == frame) + query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, PointSceneWorldItem.frame == frame) result_item = session.exec(query_stmt).first() return result_item.ego_pose_json if result_item else None @classmethod def get_scene_names(cls,s=None): with Session(engine) as session: - query_stmt = select(Scene) + query_stmt = select(PointScene) if s is not None: - query_stmt = query_stmt.where(Scene.scene_name == s) + query_stmt = query_stmt.where(PointScene.scene_name == s) query_result_item = session.exec(query_stmt) return [i.scene_name for i in query_result_item.all()] @classmethod def get_scene_desc(cls): with Session(engine) as session: - query_stmt = select(Scene) + query_stmt = select(PointScene) query_result_item = session.exec(query_stmt) result = {} for i in query_result_item.all(): @@ -115,7 +115,7 @@ class SceneService: @classmethod def get_all_objs(cls,scene: str) -> List[Dict[str, Any]]: """从 CherryPy 类中提取的辅助函数""" - scene_items:List[SceneWorldItem] = cls.get_scene_items(scene) + scene_items:List[PointSceneWorldItem] = cls.get_scene_items(scene) if scene_items is None: return [] all_objs = {} diff --git a/tools/check_labels.py b/tools/check_labels.py index e4e8732..3313147 100644 --- a/tools/check_labels.py +++ b/tools/check_labels.py @@ -2,7 +2,7 @@ from typing import List import numpy as np -from crud.entity.models import SceneWorldItem +from crud.entity.models import PointSceneWorldItem from crud.service.scene_service import SceneService @@ -44,7 +44,7 @@ class LabelChecker: # files = os.listdir(label_folder) labels = {} obj_ids = {} - scene_items:List[SceneWorldItem] = SceneService.get_scene_items(self.path) + scene_items:List[PointSceneWorldItem] = SceneService.get_scene_items(self.path) for s in scene_items: l = s.label_json frame_id = s.frame