from typing import Any, Dict, List, Optional

from sqlmodel import Session, select

from crud.config.db import engine
from crud.entity.models import PointScene, PointSceneWorldItem
from crud.entity.scene_dto import SaveWorldItem


class SceneService:
    """Query and update ``PointScene`` / ``PointSceneWorldItem`` rows.

    Every public method opens its own ``Session`` on the shared ``engine``.
    """

    # Camera names reported for every scene by ``get_scene_info``.
    # NOTE(review): the same list is returned for all scenes and is not read
    # from the database -- confirm this is intended.
    _CAMERAS = ("right", "left", "front")

    @staticmethod
    def _find_world_item(session: Session, scene, frame) -> Optional[PointSceneWorldItem]:
        """Return the world item matching (scene, frame) in *session*, or None."""
        query_stmt = select(PointSceneWorldItem).where(
            PointSceneWorldItem.scene_name == scene,
            PointSceneWorldItem.frame == frame,
        )
        return session.exec(query_stmt).first()

    @classmethod
    def _require_world_item(cls, session: Session, scene, frame) -> PointSceneWorldItem:
        """Return the world item matching (scene, frame), raising if it does not exist.

        Raises:
            LookupError: if no row matches. This gives a descriptive error
                instead of an ``AttributeError`` on ``None`` further down.
        """
        result_item = cls._find_world_item(session, scene, frame)
        if result_item is None:
            raise LookupError(f"no world item for scene={scene!r}, frame={frame!r}")
        return result_item

    @classmethod
    def get_scene_info(cls, s=None):
        """Return one summary dict per scene, or only for the scene named *s*.

        Each dict has the keys ``scene``, ``frames`` (the ``frame`` value of
        every world item belonging to the scene), ``lidar_ext``,
        ``camera_ext``, ``radar_ext``, ``aux_lidar_ext``, ``boxtype``,
        ``camera`` and ``calib``.
        """
        result_list = []
        with Session(engine) as session:
            query_stmt = select(PointScene)
            if s is not None:
                query_stmt = query_stmt.where(PointScene.scene_name == s)
            # Materialize the scenes before running further queries on the
            # same session inside the loop.
            scenes = session.exec(query_stmt).all()
            for scene in scenes:
                # NOTE(review): there is no ORDER BY, so frame order is whatever
                # the database returns -- confirm callers do not rely on it.
                world_stmt = select(PointSceneWorldItem).where(
                    PointSceneWorldItem.scene_name == scene.scene_name
                )
                scene_worlds = session.exec(world_stmt).all()
                result_list.append({
                    "scene": scene.scene_name,
                    "frames": [scene_world.frame for scene_world in scene_worlds],
                    "lidar_ext": scene.lidar_ext,
                    "camera_ext": scene.camera_ext,
                    "radar_ext": scene.radar_ext,
                    "aux_lidar_ext": scene.aux_lidar_ext,
                    "boxtype": scene.box_type,
                    # Fresh list per scene so callers can mutate it safely.
                    "camera": list(cls._CAMERAS),
                    "calib": scene.calib_json,
                })
        return result_list

    @classmethod
    def save_world_list(cls, items: List[SaveWorldItem]):
        """Batch-save annotation data.

        Sets ``label_json`` of the world item for each ``(item.scene,
        item.frame)`` to ``item.annotation``. All updates share one session
        and a single commit, so either every item is saved or none is.

        Returns:
            The string ``"ok"``.
        Raises:
            LookupError: if any (scene, frame) has no world item. Nothing is
                committed in that case.
        """
        with Session(engine) as session:
            for item in items:
                result_item = cls._require_world_item(session, item.scene, item.frame)
                result_item.label_json = item.annotation
                session.add(result_item)
            session.commit()
        return "ok"

    @classmethod
    def update_label_json(cls, scene, frame, ann):
        """Set ``label_json`` of the (scene, frame) world item to *ann* and commit.

        Returns:
            The string ``"ok"``.
        Raises:
            LookupError: if no world item matches (scene, frame).
        """
        with Session(engine) as session:
            result_item = cls._require_world_item(session, scene, frame)
            result_item.label_json = ann
            session.add(result_item)
            session.commit()
        return "ok"

    @classmethod
    def get_frame_ids(cls, scene: str):
        """Return the ``id`` column of every world item in *scene*."""
        with Session(engine) as session:
            exec_result = session.exec(
                select(PointSceneWorldItem.id).where(PointSceneWorldItem.scene_name == scene)
            )
            return exec_result.all()

    @classmethod
    def get_scene_items(cls, scene: str):
        """Return all ``PointSceneWorldItem`` rows in *scene* as a list."""
        with Session(engine) as session:
            return session.exec(
                select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene)
            ).all()

    @classmethod
    def get_label_json(cls, scene, frame):
        """Return ``label_json`` of the (scene, frame) world item, or ``[]`` if none exists."""
        with Session(engine) as session:
            result_item = cls._find_world_item(session, scene, frame)
            return result_item.label_json if result_item else []

    @classmethod
    def get_ego_pose_json(cls, scene, frame):
        """Return ``ego_pose_json`` of the (scene, frame) world item, or None if none exists."""
        with Session(engine) as session:
            result_item = cls._find_world_item(session, scene, frame)
            return result_item.ego_pose_json if result_item else None

    @classmethod
    def get_scene_names(cls, s=None):
        """Return all scene names, or just ``[s]`` if a scene named *s* exists."""
        with Session(engine) as session:
            query_stmt = select(PointScene)
            if s is not None:
                query_stmt = query_stmt.where(PointScene.scene_name == s)
            return [i.scene_name for i in session.exec(query_stmt).all()]

    @classmethod
    def get_scene_desc(cls):
        """Return a mapping ``{scene_name: desc}`` for every scene."""
        with Session(engine) as session:
            return {i.scene_name: i.desc for i in session.exec(select(PointScene)).all()}

    @classmethod
    def get_all_objs(cls, scene: str) -> List[Dict[str, Any]]:
        """Collect every distinct labelled object in *scene* with its box count.

        This helper was extracted from the CherryPy class.

        Reads the boxes in each world item's ``label_json`` and groups them by
        ``obj_type`` and ``obj_id``. A box whose ``obj_type`` or ``obj_id`` is
        missing or falsy is skipped. Each result dict has ``category``, ``id``
        and ``count`` (the number of boxes with that key across the scene).
        Results are returned in first-seen order.
        """
        scene_items: List[PointSceneWorldItem] = cls.get_scene_items(scene) or []
        all_objs: Dict[str, Dict[str, Any]] = {}
        for f in scene_items:
            # Treat a missing (None) label_json as "no boxes" rather than
            # failing with TypeError when iterating it.
            for b in f.label_json or []:
                category = b.get("obj_type")
                obj_id = b.get("obj_id")
                if not category or not obj_id:
                    continue
                k = f"{category}-{obj_id}"
                if k in all_objs:
                    all_objs[k]["count"] += 1
                else:
                    all_objs[k] = {"category": category, "id": obj_id, "count": 1}
        return list(all_objs.values())


if __name__ == '__main__':
    print(SceneService.get_scene_info())