|
|
|
|
@ -3,7 +3,7 @@ from typing import List, Dict, Any
|
|
|
|
|
from sqlmodel import select, Session
|
|
|
|
|
|
|
|
|
|
from crud.config.db import engine
|
|
|
|
|
from crud.entity.models import SceneWorldItem, Scene
|
|
|
|
|
from crud.entity.models import PointSceneWorldItem, PointScene
|
|
|
|
|
from crud.entity.scene_dto import SaveWorldItem
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -13,12 +13,12 @@ class SceneService:
|
|
|
|
|
result_list = []
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
|
|
|
|
|
query_stmt = select(Scene)
|
|
|
|
|
query_stmt = select(PointScene)
|
|
|
|
|
if s is not None:
|
|
|
|
|
query_stmt = query_stmt.where(Scene.scene_name == s)
|
|
|
|
|
query_stmt = query_stmt.where(PointScene.scene_name == s)
|
|
|
|
|
scene_all = session.exec(query_stmt)
|
|
|
|
|
for scene in scene_all:
|
|
|
|
|
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene.scene_name)
|
|
|
|
|
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene.scene_name)
|
|
|
|
|
query_result_item = session.exec(query_stmt)
|
|
|
|
|
scene_worlds = query_result_item.all()
|
|
|
|
|
item = {
|
|
|
|
|
@ -47,8 +47,8 @@ class SceneService:
|
|
|
|
|
frame = item.frame
|
|
|
|
|
ann = item.annotation
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene,
|
|
|
|
|
SceneWorldItem.frame == frame)
|
|
|
|
|
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene,
|
|
|
|
|
PointSceneWorldItem.frame == frame)
|
|
|
|
|
result_item = session.exec(query_stmt).first()
|
|
|
|
|
result_item.label_json = ann
|
|
|
|
|
session.add(result_item)
|
|
|
|
|
@ -57,8 +57,8 @@ class SceneService:
|
|
|
|
|
@classmethod
|
|
|
|
|
def update_label_json(cls, scene, frame,ann):
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene,
|
|
|
|
|
SceneWorldItem.frame == frame)
|
|
|
|
|
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene,
|
|
|
|
|
PointSceneWorldItem.frame == frame)
|
|
|
|
|
result_item = session.exec(query_stmt).first()
|
|
|
|
|
result_item.label_json = ann
|
|
|
|
|
session.add(result_item)
|
|
|
|
|
@ -69,43 +69,43 @@ class SceneService:
|
|
|
|
|
def get_frame_ids(cls, scene: str):
|
|
|
|
|
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
exec_result = session.exec(select(SceneWorldItem.id).where(SceneWorldItem.scene_name == scene))
|
|
|
|
|
exec_result = session.exec(select(PointSceneWorldItem.id).where(PointSceneWorldItem.scene_name == scene))
|
|
|
|
|
ids = exec_result.all()
|
|
|
|
|
return ids
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_scene_items(cls, scene: str):
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
items = session.exec(select(SceneWorldItem).where(SceneWorldItem.scene_name == scene)).all()
|
|
|
|
|
items = session.exec(select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene)).all()
|
|
|
|
|
return items
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_label_json(cls, scene, frame):
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, SceneWorldItem.frame == frame)
|
|
|
|
|
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, PointSceneWorldItem.frame == frame)
|
|
|
|
|
result_item = session.exec(query_stmt).first()
|
|
|
|
|
return result_item.label_json if result_item else []
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_ego_pose_json(cls, scene, frame):
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, SceneWorldItem.frame == frame)
|
|
|
|
|
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, PointSceneWorldItem.frame == frame)
|
|
|
|
|
result_item = session.exec(query_stmt).first()
|
|
|
|
|
return result_item.ego_pose_json if result_item else None
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_scene_names(cls,s=None):
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
query_stmt = select(Scene)
|
|
|
|
|
query_stmt = select(PointScene)
|
|
|
|
|
if s is not None:
|
|
|
|
|
query_stmt = query_stmt.where(Scene.scene_name == s)
|
|
|
|
|
query_stmt = query_stmt.where(PointScene.scene_name == s)
|
|
|
|
|
query_result_item = session.exec(query_stmt)
|
|
|
|
|
return [i.scene_name for i in query_result_item.all()]
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_scene_desc(cls):
|
|
|
|
|
with Session(engine) as session:
|
|
|
|
|
query_stmt = select(Scene)
|
|
|
|
|
query_stmt = select(PointScene)
|
|
|
|
|
query_result_item = session.exec(query_stmt)
|
|
|
|
|
result = {}
|
|
|
|
|
for i in query_result_item.all():
|
|
|
|
|
@ -115,7 +115,7 @@ class SceneService:
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_all_objs(cls,scene: str) -> List[Dict[str, Any]]:
|
|
|
|
|
"""从 CherryPy 类中提取的辅助函数"""
|
|
|
|
|
scene_items:List[SceneWorldItem] = cls.get_scene_items(scene)
|
|
|
|
|
scene_items:List[PointSceneWorldItem] = cls.get_scene_items(scene)
|
|
|
|
|
if scene_items is None:
|
|
|
|
|
return []
|
|
|
|
|
all_objs = {}
|
|
|
|
|
|