You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
146 lines
5.4 KiB
Python
146 lines
5.4 KiB
Python
from typing import List, Dict, Any
|
|
|
|
from sqlmodel import select, Session
|
|
|
|
from crud.config.db import engine
|
|
from crud.entity.models import SceneWorldItem, Scene
|
|
from crud.entity.scene_dto import SaveWorldItem
|
|
|
|
|
|
class SceneService:
    """Database access helpers for ``Scene`` rows and their per-frame ``SceneWorldItem`` rows.

    Every public method is a classmethod that opens its own short-lived
    ``Session`` on the module-level ``engine``, so no instance is needed.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _scene_query(s=None):
        """Build a ``select(Scene)`` statement, filtered to scene name ``s`` when given."""
        query_stmt = select(Scene)
        if s is not None:
            query_stmt = query_stmt.where(Scene.scene_name == s)
        return query_stmt

    @staticmethod
    def _world_item_query(scene, frame):
        """Build a statement selecting the world item(s) for one (scene, frame) pair."""
        return select(SceneWorldItem).where(SceneWorldItem.scene_name == scene,
                                            SceneWorldItem.frame == frame)

    @classmethod
    def _require_world_item(cls, session, scene, frame):
        """Return the first world item for (scene, frame) in ``session``.

        Raises:
            LookupError: if no such row exists (previously this surfaced as an
                opaque ``AttributeError`` on ``None``).
        """
        result_item = session.exec(cls._world_item_query(scene, frame)).first()
        if result_item is None:
            raise LookupError(f"no SceneWorldItem for scene={scene!r}, frame={frame!r}")
        return result_item

    @staticmethod
    def _count_objects(label_lists):
        """Count the labelled boxes of each distinct (category, id) object.

        Args:
            label_lists: Iterable of per-frame label lists. Each box is read with
                ``.get("obj_type")`` / ``.get("obj_id")``. A ``None`` entry is
                treated as a frame with no labels.

        Returns:
            One ``{"category", "id", "count"}`` dict per distinct object, in
            first-seen order.
        """
        all_objs = {}
        for boxes in label_lists:
            for box in boxes or []:
                category = box.get("obj_type")
                obj_id = box.get("obj_id")
                # A box missing (or with an empty) type or id cannot be identified.
                if not category or not obj_id:
                    continue
                key = f"{category}-{obj_id}"
                if key in all_objs:
                    all_objs[key]["count"] += 1
                else:
                    all_objs[key] = {"category": category, "id": obj_id, "count": 1}
        return list(all_objs.values())

    # ------------------------------------------------------------------ queries

    @classmethod
    def get_scene_info(cls, s=None):
        """Return a summary dict for every scene, or only for scene ``s``.

        Args:
            s: Optional scene name. When given, only that scene is summarised.

        Returns:
            A list with one dict per matching scene. Each dict holds the scene
            name, the ``frame`` values of its world items, and the ``*_ext``,
            box-type and calibration fields copied from the ``Scene`` row.
        """
        result_list = []
        with Session(engine) as session:
            # Materialise the scenes first so the per-scene queries below are not
            # issued while the outer result cursor is still open.
            scenes = session.exec(cls._scene_query(s)).all()
            for scene in scenes:
                # Only the frame column is needed; avoid loading label/pose JSON.
                frames = session.exec(
                    select(SceneWorldItem.frame).where(
                        SceneWorldItem.scene_name == scene.scene_name)
                ).all()
                result_list.append({
                    "scene": scene.scene_name,
                    "frames": list(frames),
                    "lidar_ext": scene.lidar_ext,
                    "camera_ext": scene.camera_ext,
                    "radar_ext": scene.radar_ext,
                    "aux_lidar_ext": scene.aux_lidar_ext,
                    "boxtype": scene.box_type,
                    # NOTE(review): camera names are hard-coded, not read from the Scene row.
                    "camera": [
                        "right",
                        "left",
                        "front",
                    ],
                    "calib": scene.calib_json,
                })
        return result_list

    @classmethod
    def save_world_list(cls, items: "List[SaveWorldItem]"):
        """Batch-save annotation data in a single transaction.

        Args:
            items: Objects exposing ``scene``, ``frame`` and ``annotation``. Each
                annotation replaces the ``label_json`` of the matching world item.

        Returns:
            The string ``"ok"`` once every item has been committed.

        Raises:
            LookupError: if any (scene, frame) pair has no stored world item.
                Nothing from the batch is committed in that case.
        """
        with Session(engine) as session:
            for item in items:
                result_item = cls._require_world_item(session, item.scene, item.frame)
                result_item.label_json = item.annotation
            # One commit keeps the batch all-or-nothing. Committing per item left
            # earlier frames saved when a later one failed.
            session.commit()
        return "ok"

    @classmethod
    def update_label_json(cls, scene, frame, ann):
        """Replace the ``label_json`` of one (scene, frame) world item.

        Returns:
            The string ``"ok"`` after the change is committed.

        Raises:
            LookupError: if no world item exists for (scene, frame).
        """
        with Session(engine) as session:
            result_item = cls._require_world_item(session, scene, frame)
            result_item.label_json = ann
            session.commit()
        return "ok"

    @classmethod
    def get_frame_ids(cls, scene: str):
        """Return the ``SceneWorldItem.id`` values belonging to ``scene``."""
        with Session(engine) as session:
            exec_result = session.exec(
                select(SceneWorldItem.id).where(SceneWorldItem.scene_name == scene))
            return exec_result.all()

    @classmethod
    def get_scene_items(cls, scene: str):
        """Return all ``SceneWorldItem`` rows of ``scene`` as a list.

        The session is closed before returning, so the items are detached.
        """
        with Session(engine) as session:
            items = session.exec(
                select(SceneWorldItem).where(SceneWorldItem.scene_name == scene)).all()
        return items

    @classmethod
    def get_label_json(cls, scene, frame):
        """Return the ``label_json`` of one frame, or ``[]`` when the frame is not stored."""
        with Session(engine) as session:
            result_item = session.exec(cls._world_item_query(scene, frame)).first()
            return result_item.label_json if result_item else []

    @classmethod
    def get_ego_pose_json(cls, scene, frame):
        """Return the ``ego_pose_json`` of one frame, or ``None`` when the frame is not stored."""
        with Session(engine) as session:
            result_item = session.exec(cls._world_item_query(scene, frame)).first()
            return result_item.ego_pose_json if result_item else None

    @classmethod
    def get_scene_names(cls, s=None):
        """Return the names of all scenes, or ``[s]`` / ``[]`` when filtering by ``s``."""
        with Session(engine) as session:
            query_result_item = session.exec(cls._scene_query(s))
            return [i.scene_name for i in query_result_item.all()]

    @classmethod
    def get_scene_desc(cls):
        """Return a ``{scene_name: desc}`` mapping for every scene."""
        with Session(engine) as session:
            scenes = session.exec(select(Scene)).all()
            return {scene.scene_name: scene.desc for scene in scenes}

    @classmethod
    def get_all_objs(cls, scene: str) -> List[Dict[str, Any]]:
        """Summarise the labelled objects across all frames of ``scene``.

        Helper extracted from the CherryPy class.

        Returns:
            One ``{"category", "id", "count"}`` dict per distinct
            (obj_type, obj_id) pair, where ``count`` is the number of boxes
            carrying it. Frames with no label JSON are skipped.
        """
        scene_items: List[SceneWorldItem] = cls.get_scene_items(scene)
        if not scene_items:
            return []
        return cls._count_objects(f.label_json for f in scene_items)
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke check: dump the summary of every scene in the configured database.
    scene_summaries = SceneService.get_scene_info()
    print(scene_summaries)
|