feat: rename table

main
吴柏豪 3 weeks ago
parent 8152086a7a
commit 12fcc7a3ac

@ -10,7 +10,7 @@ from crud.config.db import engine
def get_uuid():
return uuid.uuid1().hex
class Scene(SQLModel,table=True):
class PointScene(SQLModel, table=True):
id: str = Field(default_factory=get_uuid, max_length=32, primary_key=True)
scene_name: str = Field(max_length=256, description="事件名称", unique=True)
calib_json: List[dict] = Field(sa_type=JSON, nullable=True, description="calib是从点云到图像的校准矩阵。它是可选的但如果提供该框将投影在图像上以帮助注释。")
@ -22,7 +22,7 @@ class Scene(SQLModel,table=True):
desc: List[dict] = Field(sa_type=JSON, nullable=True, description="calib是从点云到图像的校准矩阵。它是可选的但如果提供该框将投影在图像上以帮助注释。")
class SceneWorldItem(SQLModel,table=True):
class PointSceneWorldItem(SQLModel, table=True):
__tablename__ = "scene_world_item"
id: str = Field(default_factory=get_uuid, max_length=32, primary_key=True)
scene_id: str = Field(max_length=256, description="事件id")

@ -3,7 +3,7 @@ from typing import List, Dict, Any
from sqlmodel import select, Session
from crud.config.db import engine
from crud.entity.models import SceneWorldItem, Scene
from crud.entity.models import PointSceneWorldItem, PointScene
from crud.entity.scene_dto import SaveWorldItem
@ -13,12 +13,12 @@ class SceneService:
result_list = []
with Session(engine) as session:
query_stmt = select(Scene)
query_stmt = select(PointScene)
if s is not None:
query_stmt = query_stmt.where(Scene.scene_name == s)
query_stmt = query_stmt.where(PointScene.scene_name == s)
scene_all = session.exec(query_stmt)
for scene in scene_all:
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene.scene_name)
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene.scene_name)
query_result_item = session.exec(query_stmt)
scene_worlds = query_result_item.all()
item = {
@ -47,8 +47,8 @@ class SceneService:
frame = item.frame
ann = item.annotation
with Session(engine) as session:
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene,
SceneWorldItem.frame == frame)
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene,
PointSceneWorldItem.frame == frame)
result_item = session.exec(query_stmt).first()
result_item.label_json = ann
session.add(result_item)
@ -57,8 +57,8 @@ class SceneService:
@classmethod
def update_label_json(cls, scene, frame,ann):
with Session(engine) as session:
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene,
SceneWorldItem.frame == frame)
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene,
PointSceneWorldItem.frame == frame)
result_item = session.exec(query_stmt).first()
result_item.label_json = ann
session.add(result_item)
@ -69,43 +69,43 @@ class SceneService:
def get_frame_ids(cls, scene: str):
with Session(engine) as session:
exec_result = session.exec(select(SceneWorldItem.id).where(SceneWorldItem.scene_name == scene))
exec_result = session.exec(select(PointSceneWorldItem.id).where(PointSceneWorldItem.scene_name == scene))
ids = exec_result.all()
return ids
@classmethod
def get_scene_items(cls, scene: str):
with Session(engine) as session:
items = session.exec(select(SceneWorldItem).where(SceneWorldItem.scene_name == scene)).all()
items = session.exec(select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene)).all()
return items
@classmethod
def get_label_json(cls, scene, frame):
with Session(engine) as session:
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, SceneWorldItem.frame == frame)
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, PointSceneWorldItem.frame == frame)
result_item = session.exec(query_stmt).first()
return result_item.label_json if result_item else []
@classmethod
def get_ego_pose_json(cls, scene, frame):
with Session(engine) as session:
query_stmt = select(SceneWorldItem).where(SceneWorldItem.scene_name == scene, SceneWorldItem.frame == frame)
query_stmt = select(PointSceneWorldItem).where(PointSceneWorldItem.scene_name == scene, PointSceneWorldItem.frame == frame)
result_item = session.exec(query_stmt).first()
return result_item.ego_pose_json if result_item else None
@classmethod
def get_scene_names(cls,s=None):
with Session(engine) as session:
query_stmt = select(Scene)
query_stmt = select(PointScene)
if s is not None:
query_stmt = query_stmt.where(Scene.scene_name == s)
query_stmt = query_stmt.where(PointScene.scene_name == s)
query_result_item = session.exec(query_stmt)
return [i.scene_name for i in query_result_item.all()]
@classmethod
def get_scene_desc(cls):
with Session(engine) as session:
query_stmt = select(Scene)
query_stmt = select(PointScene)
query_result_item = session.exec(query_stmt)
result = {}
for i in query_result_item.all():
@ -115,7 +115,7 @@ class SceneService:
@classmethod
def get_all_objs(cls,scene: str) -> List[Dict[str, Any]]:
"""从 CherryPy 类中提取的辅助函数"""
scene_items:List[SceneWorldItem] = cls.get_scene_items(scene)
scene_items:List[PointSceneWorldItem] = cls.get_scene_items(scene)
if scene_items is None:
return []
all_objs = {}

@ -2,7 +2,7 @@ from typing import List
import numpy as np
from crud.entity.models import SceneWorldItem
from crud.entity.models import PointSceneWorldItem
from crud.service.scene_service import SceneService
@ -44,7 +44,7 @@ class LabelChecker:
# files = os.listdir(label_folder)
labels = {}
obj_ids = {}
scene_items:List[SceneWorldItem] = SceneService.get_scene_items(self.path)
scene_items:List[PointSceneWorldItem] = SceneService.get_scene_items(self.path)
for s in scene_items:
l = s.label_json
frame_id = s.frame

Loading…
Cancel
Save