You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

193 lines
5.6 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import json
import os
import shlex
import subprocess
from typing import Any, Dict, List

from fastapi import APIRouter, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

import scene_reader
from algos import pre_annotate
from tools import check_labels as check


class SaveWorldItem(BaseModel):
    """One frame's annotation, as submitted to the /saveworldlist endpoint."""
    scene: str  # scene name; the label file goes under ./data/<scene>/label
    frame: str  # frame name; used as the label file's base name (<frame>.json)
    annotation: Any  # dumped verbatim as JSON; its structure is not validated here
class CropSceneRequest(BaseModel):
    """Request body for the /cropscene endpoint (arguments for crop_scene.py)."""
    # Raw scene id; its first 10 characters and the part before the first "_"
    # are used to build the source dataset path passed to crop_scene.py.
    rawSceneId: str
    startTime: str  # forwarded to crop_scene.py (presumably the crop start time)
    seconds: str  # forwarded to crop_scene.py (presumably the crop duration)
    desc: str  # free-text description forwarded to crop_scene.py
class PredictRotationRequest(BaseModel):
    """Request body for the /predict_rotation endpoint."""
    # Passed unchanged to pre_annotate.predict_yaw. Not validated here;
    # presumably an N x 3 point cloud (per the original author's note) — confirm.
    points: Any
# Model for one entry of the /loadworldlist request body.
class LoadWorldItem(BaseModel):
    """Identifies one frame of one scene whose annotation should be loaded."""
    # NOTE(review): the descriptions ("this is scene" / "this is frame") show up in
    # the OpenAPI schema but carry no information; consider improving them.
    scene: str = Field(description="这是scene")  # scene name
    frame: str = Field(description="这是frame")  # frame name within the scene
# Router on which the endpoints below are registered; the application is
# expected to include it.
router = APIRouter()
@router.post("/saveworldlist")
async def saveworldlist(items: List[SaveWorldItem]):
    """Save a batch of frame annotations, one JSON file per frame.

    Each item's ``annotation`` is written (indented, keys sorted) to
    ``./data/<scene>/label/<frame>.json``. The ``label`` directory is created
    if needed, and an existing file is overwritten.

    Raises:
        HTTPException: 400 if any item's scene/frame would place the label
            directory or file outside ``./data``. Every item is checked before
            anything is written, so a rejected batch writes nothing.

    Returns:
        The string ``"ok"``.
    """
    data_root = os.path.abspath("./data")

    def escapes_data_root(path: str) -> bool:
        # Reject ".." segments outright. Even when they cancel out lexically,
        # the OS resolves them physically, and symlinks can change the result.
        if ".." in path.replace("\\", "/").split("/"):
            return True
        # The prefix test catches absolute scene/frame values, which make
        # os.path.join discard the "./data" prefix entirely.
        return not os.path.abspath(path).startswith(data_root + os.sep)

    targets = []
    for item in items:
        label_dir = os.path.join("./data", item.scene, "label")
        file_path = os.path.join(label_dir, f"{item.frame}.json")
        # Both paths are checked: makedirs(label_dir) must not create
        # directories outside ./data either.
        if escapes_data_root(label_dir) or escapes_data_root(file_path):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid scene or frame: {item.scene!r}, {item.frame!r}",
            )
        targets.append((label_dir, file_path, item.annotation))

    for label_dir, file_path, annotation in targets:
        os.makedirs(label_dir, exist_ok=True)
        # TODO: store annotations in a database.
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, sort_keys=True)
    return "ok"
@router.post("/cropscene")
async def cropscene(request_data: CropSceneRequest):
    """Crop a scene by running ``tools/dataset_preprocess/crop_scene.py``.

    The tool is started without a shell: every request field becomes exactly
    one argv entry, so user-supplied text (notably ``desc``) cannot inject
    shell syntax. The tool's stdout and stderr are captured together.

    Returns:
        ``{"code": <process exit status, 0 on success>, "log": [<output lines,
        stripped>]}``. If the interpreter cannot be started at all, ``code``
        is 127 and ``log`` holds the OS error message.
    """
    raw_scene_id = request_data.rawSceneId
    timestamp = raw_scene_id.split("_")[0]
    args = [
        "python",
        "./tools/dataset_preprocess/crop_scene.py",
        "generate",
        f"{raw_scene_id[0:10]}/{timestamp}_preprocessed/dataset_2hz",
        "-",
        request_data.startTime,
        request_data.seconds,
        request_data.desc,
    ]
    print(f"Executing command: {shlex.join(args)}")
    try:
        # subprocess.run blocks until the tool exits. Run it in a worker thread
        # so this async handler does not stall the event loop meanwhile.
        result = await run_in_threadpool(
            subprocess.run,
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable output must not crash the request
        )
    except OSError as err:
        # The old shell-based call never raised here. Report a non-zero status
        # instead (127 is the shell's "command not found" code).
        return {"code": 127, "log": [str(err)]}
    log = [line.strip() for line in result.stdout.splitlines()]
    return {"code": result.returncode, "log": log}
@router.get("/checkscene")
def checkscene(scene: str):
    """Run the label checker over one scene and return its messages.

    NOTE(review): ``scene`` comes straight from the query string and is joined
    onto ./data unchecked, so ".." segments can point outside ./data.
    """
    scene_dir = os.path.join("./data", scene)
    checker = check.LabelChecker(scene_dir)
    checker.check()
    print(checker.messages)
    return checker.messages
@router.post("/predict_rotation")
async def predict_rotation(request_data: PredictRotationRequest):
    """Predict the rotation angle of the submitted points.

    FastAPI has already parsed the JSON body into ``PredictRotationRequest``.
    Its ``points`` go to ``pre_annotate.predict_yaw``, and the result is
    returned under the ``"angle"`` key.

    NOTE(review): predict_yaw runs synchronously inside an async handler; if it
    is CPU-heavy it blocks the event loop for its whole duration.
    """
    yaw = pre_annotate.predict_yaw(request_data.points)
    return {"angle": yaw}
@router.get("/auto_annotate")
def auto_annotate(scene: str, frame: str):
    """Run automatic pre-annotation on one frame's point cloud.

    Reads ``./data/<scene>/lidar/<frame>.pcd`` through
    ``pre_annotate.annotate_file`` and returns its result unchanged.

    Raises:
        HTTPException: 400 if scene or frame contains a ".." path segment,
            which could reach files outside ``./data``; 404 if the ``.pcd``
            file does not exist.
    """
    print(f"Auto annotate {scene}, {frame}")
    file_path = f'./data/{scene}/lidar/{frame}.pcd'
    # The path is built by plain string formatting (not os.path.join), so an
    # absolute scene/frame still stays under ./data; only ".." can climb out.
    if ".." in file_path.replace("\\", "/").split("/"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid scene or frame: {scene!r}, {frame!r}",
        )
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail=f"File not found: {file_path}")
    return pre_annotate.annotate_file(file_path)
@router.get("/load_annotation")
def load_annotation(scene: str, frame: str):
    """Return the stored annotation of one frame, as read by scene_reader."""
    annotation = scene_reader.read_annotations(scene, frame)
    return annotation
@router.get("/load_ego_pose")
def load_ego_pose(scene: str, frame: str):
    """Return the ego-vehicle pose for one frame, as read by scene_reader."""
    pose = scene_reader.read_ego_pose(scene, frame)
    return pose
@router.post("/loadworldlist")
async def load_world_list(request: Request):
    """Load the annotations of a batch of frames.

    The raw body is parsed by hand rather than declared as
    ``List[LoadWorldItem]``, presumably so that clients sending a non-JSON
    content type are still accepted (NOTE(review): confirm).

    Expected body: a UTF-8 JSON array of ``{"scene": str, "frame": str}``
    objects.

    Returns:
        A list of ``{"scene", "frame", "annotation"}`` dicts in request order.
        ``annotation`` comes from ``scene_reader.read_annotations``.

    Raises:
        HTTPException: 400 if the body is not valid UTF-8 JSON, is not an
            array, or has an entry that is not a valid ``{scene, frame}``
            object. All entries are validated before any annotation is read.
    """
    body = await request.body()
    try:
        items = json.loads(body.decode('utf-8'))
    except (UnicodeDecodeError, json.JSONDecodeError) as err:
        raise HTTPException(status_code=400, detail="请求体不是有效的 JSON 格式") from err
    if not isinstance(items, list):
        raise HTTPException(
            status_code=400,
            detail="Request body must be a JSON array of {scene, frame} objects",
        )

    worlds = []
    for index, entry in enumerate(items):
        if not isinstance(entry, dict):
            raise HTTPException(
                status_code=400,
                detail=f"Item {index} must be a JSON object with 'scene' and 'frame'",
            )
        try:
            worlds.append(LoadWorldItem(**entry))
        except ValueError as err:  # pydantic's ValidationError subclasses ValueError
            raise HTTPException(
                status_code=400, detail=f"Item {index} is invalid: {err}"
            ) from err

    # TODO: query annotations from a database.
    return [
        {
            "scene": w.scene,
            "frame": w.frame,
            "annotation": scene_reader.read_annotations(w.scene, w.frame),
        }
        for w in worlds
    ]
@router.get("/datameta")
def datameta():
    """Return the metadata of every scene, as provided by scene_reader."""
    scenes = scene_reader.get_all_scenes()
    return scenes
@router.get("/scenemeta")
def scenemeta(scene: str):
    """Return the metadata of a single scene, as provided by scene_reader."""
    meta = scene_reader.get_one_scene(scene)
    return meta
@router.get("/get_all_scene_desc")
def get_all_scene_desc():
    """Return the descriptions of all scenes, as provided by scene_reader."""
    descriptions = scene_reader.get_all_scene_desc()
    return descriptions
@router.get("/objs_of_scene")
def objs_of_scene(scene: str):
    """Return the distinct labelled objects of a scene with occurrence counts.

    The scene's label files under ./data/<scene> are scanned by get_all_objs.
    """
    # TODO: query this from a database instead of scanning label files.
    scene_dir = os.path.join("./data", scene)
    return get_all_objs(scene_dir)
# --- Helper functions ---
def get_all_objs(path: str) -> List[Dict[str, Any]]:
    """Aggregate the labelled objects found in one scene directory.

    Scans ``<path>/label/*.json``. Each file should hold a JSON list of box
    dicts. Every box with a non-empty ``obj_type`` and an ``obj_id`` that is
    neither ``None`` nor ``""`` counts as one occurrence of that object.

    Args:
        path: Scene directory, i.e. the parent of its ``label`` folder.

    Returns:
        One ``{"category", "id", "count"}`` dict per distinct
        (obj_type, obj_id) pair. ``count`` is the number of boxes carrying that
        pair, and the first-seen values are kept as ``category``/``id``. Files
        are visited in sorted name order, so the result order is deterministic.
        Returns ``[]`` when ``<path>/label`` is not a directory.

    A file that cannot be read or parsed, or whose top level is not a JSON
    list, is reported with ``print`` and skipped. List entries that are not
    dicts are ignored.
    """
    label_folder = os.path.join(path, "label")
    if not os.path.isdir(label_folder):
        return []

    all_objs: Dict[Any, Dict[str, Any]] = {}
    for name in sorted(os.listdir(label_folder)):
        if not name.endswith(".json"):
            continue
        try:
            with open(os.path.join(label_folder, name), encoding="utf-8") as fd:
                boxes = json.load(fd)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error processing file {name}: {e}")
            continue
        if not isinstance(boxes, list):
            print(
                f"Error processing file {name}: expected a JSON list of boxes, "
                f"got {type(boxes).__name__}"
            )
            continue

        for box in boxes:
            if not isinstance(box, dict):
                continue
            category = box.get("obj_type")
            obj_id = box.get("obj_id")
            # Only None / "" mean "no id": a numeric id of 0 is a real id
            # (the string "0" was always counted).
            if not category or obj_id is None or obj_id == "":
                continue
            # Key on the stringified pair rather than f"{cat}-{id}": the joined
            # string collides when either part itself contains "-". str() still
            # merges e.g. 1 and "1", as the old string key did.
            key = (str(category), str(obj_id))
            entry = all_objs.get(key)
            if entry is None:
                all_objs[key] = {"category": category, "id": obj_id, "count": 1}
            else:
                entry["count"] += 1
    return list(all_objs.values())