You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

146 lines
5.6 KiB
Python

from typing import List, Dict, Any
from sqlmodel import select, Session
from crud.config.db import engine
from crud.entity.models import PointSceneWorldItem, PointScene
from crud.entity.scene_dto import SaveWorldItem
class SceneService:
    """Database access helpers for scenes (``PointScene``) and their per-frame
    world items (``PointSceneWorldItem``).

    Every public method opens its own short-lived ``Session`` on the shared
    ``engine`` and closes it before returning.
    """

    @staticmethod
    def _world_item_stmt(scene, frame):
        """Build the SELECT for the world item identified by ``(scene, frame)``."""
        return select(PointSceneWorldItem).where(
            PointSceneWorldItem.scene_name == scene,
            PointSceneWorldItem.frame == frame,
        )

    @classmethod
    def _require_world_item(cls, session, scene, frame):
        """Return the world item for ``(scene, frame)`` loaded in ``session``.

        :raises ValueError: if no such row exists.
        """
        world_item = session.exec(cls._world_item_stmt(scene, frame)).first()
        if world_item is None:
            raise ValueError(
                f"no world item found for scene={scene!r}, frame={frame!r}"
            )
        return world_item

    @classmethod
    def get_scene_info(cls, s=None):
        """Return metadata for every scene, or only for the scene named ``s``.

        :param s: optional scene name; when given, only that scene is listed.
        :return: a list with one dict per scene, holding the keys ``scene``,
            ``frames`` (the ``frame`` value of each of the scene's world
            items), ``lidar_ext``, ``camera_ext``, ``radar_ext``,
            ``aux_lidar_ext``, ``boxtype``, ``camera`` (a fixed list of camera
            names) and ``calib``.
        """
        result_list = []
        with Session(engine) as session:
            query_stmt = select(PointScene)
            if s is not None:
                query_stmt = query_stmt.where(PointScene.scene_name == s)
            # Materialize the scenes first so the outer result is not still
            # being consumed while the per-scene queries run on this session.
            scenes = session.exec(query_stmt).all()
            for scene in scenes:
                items_stmt = select(PointSceneWorldItem).where(
                    PointSceneWorldItem.scene_name == scene.scene_name
                )
                scene_worlds = session.exec(items_stmt).all()
                result_list.append({
                    "scene": scene.scene_name,
                    "frames": [scene_world.frame for scene_world in scene_worlds],
                    "lidar_ext": scene.lidar_ext,
                    "camera_ext": scene.camera_ext,
                    "radar_ext": scene.radar_ext,
                    "aux_lidar_ext": scene.aux_lidar_ext,
                    "boxtype": scene.box_type,
                    "camera": [
                        "right",
                        "left",
                        "front",
                    ],
                    "calib": scene.calib_json,
                })
        return result_list

    @classmethod
    def save_world_list(cls, items: List[SaveWorldItem]):
        """Batch-save annotations: set ``label_json`` for each item's (scene, frame).

        All updates run in one session and are committed together. If any
        item has no matching row, nothing is committed.

        :param items: objects exposing ``scene``, ``frame`` and ``annotation``.
        :return: the string ``"ok"``.
        :raises ValueError: if an item's (scene, frame) has no world item row.
        """
        with Session(engine) as session:
            for item in items:
                world_item = cls._require_world_item(session, item.scene, item.frame)
                world_item.label_json = item.annotation
                session.add(world_item)
            # Single commit: the batch is saved atomically instead of row by row.
            session.commit()
        return "ok"

    @classmethod
    def update_label_json(cls, scene, frame, ann):
        """Overwrite ``label_json`` of the world item for ``(scene, frame)``.

        :param ann: the new annotation value stored in ``label_json``.
        :return: the string ``"ok"``.
        :raises ValueError: if no world item exists for ``(scene, frame)``.
        """
        with Session(engine) as session:
            world_item = cls._require_world_item(session, scene, frame)
            world_item.label_json = ann
            session.add(world_item)
            session.commit()
        return "ok"

    @classmethod
    def get_frame_ids(cls, scene: str):
        """Return the ``id`` column values of all world items of ``scene``."""
        with Session(engine) as session:
            exec_result = session.exec(
                select(PointSceneWorldItem.id).where(
                    PointSceneWorldItem.scene_name == scene
                )
            )
            return exec_result.all()

    @classmethod
    def get_scene_items(cls, scene: str):
        """Return all ``PointSceneWorldItem`` rows of ``scene`` as a list.

        The session is closed before returning, so the returned objects are
        detached from it.
        """
        with Session(engine) as session:
            return session.exec(
                select(PointSceneWorldItem).where(
                    PointSceneWorldItem.scene_name == scene
                )
            ).all()

    @classmethod
    def get_label_json(cls, scene, frame):
        """Return ``label_json`` for ``(scene, frame)``, or ``[]`` if no row exists."""
        with Session(engine) as session:
            world_item = session.exec(cls._world_item_stmt(scene, frame)).first()
            return world_item.label_json if world_item else []

    @classmethod
    def get_ego_pose_json(cls, scene, frame):
        """Return ``ego_pose_json`` for ``(scene, frame)``, or ``None`` if no row exists."""
        with Session(engine) as session:
            world_item = session.exec(cls._world_item_stmt(scene, frame)).first()
            return world_item.ego_pose_json if world_item else None

    @classmethod
    def get_scene_names(cls, s=None):
        """Return all scene names, or ``[s]`` when ``s`` names an existing scene."""
        with Session(engine) as session:
            query_stmt = select(PointScene)
            if s is not None:
                query_stmt = query_stmt.where(PointScene.scene_name == s)
            return [scene.scene_name for scene in session.exec(query_stmt).all()]

    @classmethod
    def get_scene_desc(cls):
        """Return a mapping ``{scene_name: desc}`` covering every scene."""
        with Session(engine) as session:
            scenes = session.exec(select(PointScene)).all()
            return {scene.scene_name: scene.desc for scene in scenes}

    @classmethod
    def get_all_objs(cls, scene: str) -> List[Dict[str, Any]]:
        """Count the labelled boxes of each object across all frames of ``scene``.

        Helper extracted from the CherryPy class. Boxes are read from each
        world item's ``label_json``. Boxes missing (or with a falsy)
        ``obj_type`` or ``obj_id`` are skipped.

        :return: one dict per distinct (``obj_type``, ``obj_id``) pair with the
            keys ``category``, ``id`` and ``count`` (the number of boxes), in
            first-seen order.
        """
        scene_items: List[PointSceneWorldItem] = cls.get_scene_items(scene)
        if not scene_items:
            return []
        all_objs: Dict[Any, Dict[str, Any]] = {}
        for world_item in scene_items:
            # NOTE(review): guards a NULL label_json (presumably a frame that
            # was never annotated) -- confirm whether the column is nullable.
            for box in world_item.label_json or []:
                category = box.get("obj_type")
                obj_id = box.get("obj_id")
                if not category or not obj_id:
                    continue
                # Tuple key: a "category-id" string could collide when either
                # part itself contains a dash.
                key = (category, obj_id)
                entry = all_objs.get(key)
                if entry is None:
                    all_objs[key] = {"category": category, "id": obj_id, "count": 1}
                else:
                    entry["count"] += 1
        return list(all_objs.values())
if __name__ == '__main__':
    # Manual smoke check against the database behind `engine`: dump every
    # scene's metadata. SceneService.get_frame_ids("example") and
    # SceneService.get_scene_items("example") are also useful spot checks.
    all_scene_info = SceneService.get_scene_info()
    print(all_scene_info)