add multi seq training

This commit is contained in:
2024-09-23 14:30:51 +08:00
parent 6cdff9c83f
commit 3c4077ec4f
12 changed files with 152 additions and 93 deletions

View File

@@ -3,6 +3,7 @@ import numpy as np
import json
import cv2
import trimesh
import torch
from utils.pts import PtsUtil
class DataLoadUtil:
@@ -13,8 +14,21 @@ class DataLoadUtil:
return path
@staticmethod
def get_label_path(root, scene_name):
path = os.path.join(root,scene_name, f"label.json")
def get_label_num(root, scene_name):
label_dir = os.path.join(root,scene_name,"label")
return len(os.listdir(label_dir))
@staticmethod
def get_label_path(root, scene_name, seq_idx):
label_dir = os.path.join(root,scene_name,"label")
if not os.path.exists(label_dir):
os.makedirs(label_dir)
path = os.path.join(label_dir,f"{seq_idx}.json")
return path
@staticmethod
def get_label_path_old(root, scene_name):
path = os.path.join(root,scene_name,"label.json")
return path
@staticmethod
@@ -45,11 +59,14 @@ class DataLoadUtil:
mesh.export(model_path)
@staticmethod
def save_target_mesh_at_world_space(root, model_dir, scene_name):
def save_target_mesh_at_world_space(root, model_dir, scene_name, display_table_as_world_space_origin=True):
scene_info = DataLoadUtil.load_scene_info(root, scene_name)
target_name = scene_info["target_name"]
transformation = scene_info[target_name]
location = transformation["location"]
if display_table_as_world_space_origin:
location = transformation["location"] - DataLoadUtil.DISPLAY_TABLE_POSITION
else:
location = transformation["location"]
rotation_euler = transformation["rotation_euler"]
pose_mat = trimesh.transformations.euler_matrix(*rotation_euler)
pose_mat[:3, 3] = location
@@ -181,7 +198,9 @@ class DataLoadUtil:
@staticmethod
def get_real_cam_O_from_cam_L(cam_L, cam_O_to_cam_L, display_table_as_world_space_origin=True):
nO_to_display_table_pose = cam_L.cpu().numpy() @ cam_O_to_cam_L
if isinstance(cam_L, torch.Tensor):
cam_L = cam_L.cpu().numpy()
nO_to_display_table_pose = cam_L @ cam_O_to_cam_L
if display_table_as_world_space_origin:
display_table_to_world = np.eye(4)
display_table_to_world[:3, 3] = DataLoadUtil.DISPLAY_TABLE_POSITION