import json
import os
import shlex
import subprocess
from typing import Any, Dict, List

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, ValidationError

import scene_reader
from tools import check_labels as check
from algos import pre_annotate

# Root directory that holds one sub-directory per scene.
DATA_ROOT = "./data"


class SaveWorldItem(BaseModel):
    # One frame's annotation to be written by /saveworldlist.
    scene: str
    frame: str
    annotation: Any


class CropSceneRequest(BaseModel):
    # Arguments forwarded to tools/dataset_preprocess/crop_scene.py by /cropscene.
    rawSceneId: str
    startTime: str
    seconds: str
    desc: str


class PredictRotationRequest(BaseModel):
    points: Any  # presumably an N x 3 point cloud — TODO confirm against pre_annotate.predict_yaw


class LoadWorldItem(BaseModel):
    # One (scene, frame) pair requested by /loadworldlist.
    scene: str = Field(description="这是scene")
    frame: str = Field(description="这是frame")


router = APIRouter()


def _safe_join(base: str, *parts: str) -> str:
    """Join ``parts`` onto ``base``, insisting the result lies strictly inside ``base``.

    ``scene`` / ``frame`` values arrive straight from HTTP clients; without
    this check values such as ``"../.."`` or absolute paths would let a caller
    read or overwrite files outside the data directory. Normalisation is
    lexical (``os.path.abspath``, not ``realpath``) so symlinked scene
    directories are not rejected.

    Returns:
        The joined path, in the same form ``os.path.join(base, *parts)`` gives.

    Raises:
        HTTPException(400): a part contains a NUL byte, or the normalised path
            equals ``base`` or falls outside it.
    """
    if any("\x00" in p for p in parts):
        raise HTTPException(status_code=400, detail="Invalid path component")
    joined = os.path.join(base, *parts)
    base_abs = os.path.abspath(base)
    target_abs = os.path.abspath(joined)
    try:
        inside = os.path.commonpath([base_abs, target_abs]) == base_abs
    except ValueError:  # e.g. paths on different drives (Windows)
        inside = False
    if not inside or target_abs == base_abs:
        raise HTTPException(status_code=400, detail="Invalid path component")
    return joined


@router.post("/saveworldlist")
async def saveworldlist(items: List[SaveWorldItem]):
    """Save annotations for a batch of frames.

    Each item's ``annotation`` is written as indented, key-sorted JSON to
    ``./data/<scene>/label/<frame>.json``; the ``label`` directory is created
    when missing. Every destination path is validated before anything is
    written, so an HTTP 400 for a bad path means nothing in the batch was
    saved.

    Returns:
        The string ``"ok"``.

    Raises:
        HTTPException(400): a scene/frame would resolve outside its directory.
    """
    # Validate all destinations up front so one bad item writes nothing.
    targets = []
    for item in items:
        label_dir = os.path.join(_safe_join(DATA_ROOT, item.scene), "label")
        file_path = _safe_join(label_dir, f"{item.frame}.json")
        targets.append((label_dir, file_path, item.annotation))

    # TODO: persist annotations to a database.
    for label_dir, file_path, ann in targets:
        os.makedirs(label_dir, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(ann, f, indent=2, sort_keys=True)
    return "ok"


@router.post("/cropscene")
async def cropscene(request_data: CropSceneRequest):
    """Crop a scene by running ``tools/dataset_preprocess/crop_scene.py generate``.

    The script input is ``<rawSceneId[:10]>/<timestamp>_preprocessed/dataset_2hz``,
    where ``timestamp`` is the part of ``rawSceneId`` before the first ``_``.
    Arguments are passed as a list (no shell), so client-supplied text such as
    ``desc`` cannot inject shell commands.

    Returns:
        ``{"code": int, "log": list[str]}`` — ``code`` is the process return
        code (0 on success, a negative signal number if killed, -1 if the
        process could not be started); ``log`` is the combined stdout/stderr,
        one stripped string per line.
    """
    rawdata = request_data.rawSceneId
    timestamp = rawdata.split("_")[0]
    args = [
        "python",
        "./tools/dataset_preprocess/crop_scene.py",
        "generate",
        f"{rawdata[0:10]}/{timestamp}_preprocessed/dataset_2hz",
        "-",
        request_data.startTime,
        request_data.seconds,
        request_data.desc,
    ]
    cmd_display = " ".join(shlex.quote(a) for a in args)
    print(f"Executing command: {cmd_display}")

    # NOTE: subprocess.run blocks this async handler (and the event loop)
    # until the script exits.
    try:
        result = subprocess.run(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the same log
            text=True,
            errors="replace",  # never fail on undecodable script output
        )
    except OSError as e:
        return {"code": -1, "log": [f"Failed to start crop_scene.py: {e}"]}

    log = [line.strip() for line in result.stdout.splitlines()]
    return {"code": result.returncode, "log": log}


@router.get("/checkscene")
def checkscene(scene: str):
    """Run ``check_labels.LabelChecker`` on ``./data/<scene>`` and return its messages.

    Raises:
        HTTPException(400): ``scene`` would resolve outside ``./data``.
    """
    ck = check.LabelChecker(_safe_join(DATA_ROOT, scene))
    ck.check()
    print(ck.messages)
    return ck.messages


@router.post("/predict_rotation")
async def predict_rotation(request_data: PredictRotationRequest):
    """Predict a yaw angle for the posted points.

    Returns:
        ``{"angle": pre_annotate.predict_yaw(points)}``.
    """
    # FastAPI has already parsed the JSON body into the pydantic model.
    return {"angle": pre_annotate.predict_yaw(request_data.points)}


@router.get("/auto_annotate")
def auto_annotate(scene: str, frame: str):
    """Auto-annotate ``./data/<scene>/lidar/<frame>.pcd`` via ``pre_annotate.annotate_file``.

    Raises:
        HTTPException(400): scene/frame would resolve outside its directory.
        HTTPException(404): the .pcd file does not exist.
    """
    print(f"Auto annotate {scene}, {frame}")
    lidar_dir = os.path.join(_safe_join(DATA_ROOT, scene), "lidar")
    file_path = _safe_join(lidar_dir, f"{frame}.pcd")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail=f"File not found: {file_path}")
    return pre_annotate.annotate_file(file_path)


@router.get("/load_annotation")
def load_annotation(scene: str, frame: str):
    """Return ``scene_reader.read_annotations(scene, frame)``."""
    return scene_reader.read_annotations(scene, frame)


@router.get("/load_ego_pose")
def load_ego_pose(scene: str, frame: str):
    """Return ``scene_reader.read_ego_pose(scene, frame)``."""
    return scene_reader.read_ego_pose(scene, frame)


@router.post("/loadworldlist")
async def load_world_list(request: Request):
    """Load annotations for a batch of (scene, frame) pairs.

    The raw request body is parsed manually and must be a JSON array of
    objects carrying ``scene`` and ``frame``.

    Returns:
        A list of ``{"scene", "frame", "annotation"}`` dicts in request order,
        with ``annotation`` from ``scene_reader.read_annotations``.

    Raises:
        HTTPException(400): body is not UTF-8 JSON, or not an array of objects.
        HTTPException(422): an entry fails ``LoadWorldItem`` validation.
    """
    body = await request.body()
    try:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
        items = json.loads(body.decode("utf-8"))
    except ValueError:
        raise HTTPException(status_code=400, detail="请求体不是有效的 JSON 格式")
    if not isinstance(items, list):
        raise HTTPException(status_code=400, detail="Request body must be a JSON array")

    worlds = []
    for entry in items:
        if not isinstance(entry, dict):
            raise HTTPException(status_code=400, detail="Each item must be a JSON object")
        try:
            worlds.append(LoadWorldItem(**entry))
        except ValidationError as e:
            raise HTTPException(status_code=422, detail=str(e)) from e

    # TODO: query annotations from a database.
    return [
        {
            "scene": w.scene,
            "frame": w.frame,
            "annotation": scene_reader.read_annotations(w.scene, w.frame),
        }
        for w in worlds
    ]


@router.get("/datameta")
def datameta():
    """Return ``scene_reader.get_all_scenes()`` (metadata for all scenes)."""
    return scene_reader.get_all_scenes()


@router.get("/scenemeta")
def scenemeta(scene: str):
    """Return ``scene_reader.get_one_scene(scene)`` (metadata for one scene)."""
    return scene_reader.get_one_scene(scene)


@router.get("/get_all_scene_desc")
def get_all_scene_desc():
    """Return ``scene_reader.get_all_scene_desc()`` (descriptions of all scenes)."""
    return scene_reader.get_all_scene_desc()


@router.get("/objs_of_scene")
def objs_of_scene(scene: str):
    """Return the distinct labelled objects of ``./data/<scene>`` (see ``get_all_objs``).

    Raises:
        HTTPException(400): ``scene`` would resolve outside ``./data``.
    """
    # TODO: query the image list from a database.
    return get_all_objs(_safe_join(DATA_ROOT, scene))


# --- Helper functions ---

def get_all_objs(path: str) -> List[Dict[str, Any]]:
    """Tally the distinct labelled objects across all frames of a scene.

    Reads every ``*.json`` file in ``<path>/label`` in sorted filename order.
    Each file is expected to hold a JSON list of box dicts; boxes without a
    truthy ``obj_type`` or ``obj_id`` are ignored. Unreadable or malformed
    files are reported with ``print`` and skipped, as are non-dict entries,
    so one bad file does not fail the whole request.

    Returns:
        One ``{"category", "id", "count"}`` dict per distinct
        (obj_type, obj_id) pair, in order of first appearance, where
        ``count`` is the number of boxes carrying that pair. ``[]`` if the
        label folder does not exist.
    """
    label_folder = os.path.join(path, "label")
    if not os.path.isdir(label_folder):
        return []

    # Keyed by (category, id) as strings: a tuple cannot collide the way a
    # "category-id" string does when either part contains "-", while str()
    # still merges ids such as 5 and "5".
    all_objs: Dict[Any, Dict[str, Any]] = {}
    for name in sorted(os.listdir(label_folder)):
        if not name.endswith(".json"):
            continue
        try:
            with open(os.path.join(label_folder, name), encoding="utf-8") as fd:
                boxes = json.load(fd)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error processing file {name}: {e}")
            continue
        if not isinstance(boxes, list):
            print(f"Error processing file {name}: expected a JSON list of boxes")
            continue

        for b in boxes:
            if not isinstance(b, dict):
                continue
            category = b.get("obj_type")
            obj_id = b.get("obj_id")
            if not category or not obj_id:
                continue
            key = (str(category), str(obj_id))
            entry = all_objs.get(key)
            if entry is None:
                all_objs[key] = {"category": category, "id": obj_id, "count": 1}
            else:
                entry["count"] += 1
    return list(all_objs.values())