import os
import re
import sys

import numpy as np
import torch
import trimesh
from torch.utils.data import DataLoader, Dataset

# The project root is four directory levels above this file; make it
# importable before pulling in project-local modules.
path = os.path.abspath(__file__)
for i in range(4):
    path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)

from utils.omni_util import OmniUtil
from utils.view_util import ViewUtil
from runners.preprocessors.object_pose.abstract_object_pose_preprocessor import ObjectPosePreprocessor
from configs.config import ConfigManager


class ObjectPoseInferenceDataset(Dataset):
    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
    ):
        self.data_dir = data_dir
        self.empty_frame = set()
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.object_data_list = self.get_object_datalist()
        self.object_name_list = list(self.object_data_list.keys())
        self.mesh_dir_path = os.path.join(self.data_dir, source, "objects")
        self.meshes = {}
        self.load_all_meshes()

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path, target = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path, object_name=target)
        return frame_data

    def load_all_meshes(self):
        """Load the simplified scan mesh for every object in the mesh directory."""
        object_name_list = os.listdir(self.mesh_dir_path)
        for object_name in object_name_list:
            mesh_path = os.path.join(self.mesh_dir_path, object_name, "Scan", "Simp.obj")
            mesh = trimesh.load(mesh_path)
            # Scale each axis down by 1000 (e.g. mm -> m).
            object_model_scale = [0.001, 0.001, 0.001]
            mesh.apply_scale(object_model_scale)
            self.meshes[object_name] = mesh

    def get_datalist(self):
        """Collect (frame_path, target_object) pairs across all scenes."""
        # Accumulate across scenes; initializing this list inside the scene
        # loop would silently drop every scene but the last.
        datalist = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            file_list = os.listdir(scene_path)
            for file in file_list:
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    target_list = OmniUtil.get_object_list(frame_path)
                    for target in target_list:
                        datalist.append((frame_path, target))
                    if len(target_list) == 0:
                        self.empty_frame.add(frame_path)
        return datalist

    def get_object_datalist(self):
        """Group frame paths by the object they contain."""
        object_datalist = {}
        for frame_path, target in self.data_list:
            if target not in object_datalist:
                object_datalist[target] = []
            object_datalist[target].append(frame_path)
        return object_datalist

    def get_object_data_batch(self, object_name):
        """Stack every frame that observes `object_name` into numpy batches."""
        object_data_list = self.object_data_list[object_name]
        batch_data = {
            "frame_path_list": [],
            "rgb_batch": [],
            "depth_batch": [],
            "seg_batch": [],
            "gt_pose_batch": [],
            "K": None,
            "mesh": None,
        }
        for frame_path in object_data_list:
            frame_data = self.load_frame_data(frame_path, object_name)
            batch_data["frame_path_list"].append(frame_path)
            batch_data["rgb_batch"].append(frame_data["rgb"])
            batch_data["depth_batch"].append(frame_data["depth"])
            batch_data["seg_batch"].append(frame_data["seg"])
            batch_data["gt_pose_batch"].append(frame_data["gt_pose"])
            # K and mesh are the same for every frame of one object.
            batch_data["K"] = frame_data["K"]
            batch_data["mesh"] = frame_data["mesh"]
        batch_data["rgb_batch"] = np.asarray(batch_data["rgb_batch"], dtype=np.uint8)
        batch_data["depth_batch"] = np.asarray(batch_data["depth_batch"])
        batch_data["seg_batch"] = np.asarray(batch_data["seg_batch"])
        batch_data["gt_pose_batch"] = np.asarray(batch_data["gt_pose_batch"])
        return batch_data
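    # Expected batch layout (inferred from the loaders above, not asserted
    # by the code):
    #   rgb_batch:     (N, H, W, 3) uint8 images
    #   depth_batch:   (N, H, W) float32 depth maps
    #   seg_batch:     (N, H, W) per-object segmentation masks
    #   gt_pose_batch: (N, 4, 4) object-to-camera transforms
    # where N is the number of frames observing the object; K and mesh are
    # shared by all N frames.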
    def load_frame_data(self, frame_path, object_name):
        """Load one frame's RGB, depth, mask, intrinsics, and GT object pose."""
        rgb = OmniUtil.get_rgb(frame_path)
        depth = OmniUtil.get_depth(frame_path)
        seg = OmniUtil.get_single_seg(frame_path, object_name)
        K = OmniUtil.get_intrinsic_matrix(frame_path)
        gt_obj_pose = OmniUtil.get_o2c_pose(frame_path, object_name)
        ret_dict = {
            "frame_path": frame_path,
            "rgb": rgb.astype(np.float32),
            "depth": depth.astype(np.float32),
            "seg": seg,
            "K": K.astype(np.float32),
            "object_name": object_name,
            "mesh": self.meshes[object_name],
            "gt_pose": gt_obj_pose.astype(np.float32),
        }
        return ret_dict


class FoundationPosePreprocessor(ObjectPosePreprocessor):
    def __init__(self, config_path):
        super().__init__(config_path)

    def run(self):
        for dataset_config in self.dataset_list_config:
            dataset = ObjectPoseInferenceDataset(
                source=dataset_config["source"],
                data_type=dataset_config["data_type"],
                data_dir=dataset_config["data_dir"],
            )
            result = self.prediction(dataset)
            self.save_processed_data(result, dataset_config)

    def prediction(self, dataset):
        """Run batched pose inference per object and index results by frame."""
        final_result = {}
        for cnt, object_name in enumerate(dataset.object_name_list, start=1):
            print(f"Processing object: {object_name} ({cnt}/{len(dataset.object_name_list)})")
            object_data_batch = dataset.get_object_data_batch(object_name)
            print(f"batch size of object {object_name}: {len(object_data_batch['frame_path_list'])}")
            pose_batch, result_batch = ViewUtil.get_object_pose_batch(
                object_data_batch["K"],
                object_data_batch["mesh"],
                object_data_batch["rgb_batch"],
                object_data_batch["depth_batch"],
                object_data_batch["seg_batch"],
                object_data_batch["gt_pose_batch"],
                self.web_server_config["port"],
            )
            for frame_path, pred_pose, gt_pose, result in zip(
                object_data_batch["frame_path_list"],
                pose_batch,
                object_data_batch["gt_pose_batch"],
                result_batch,
            ):
                if frame_path not in final_result:
                    final_result[frame_path] = {}
                final_result[frame_path][object_name] = {
                    "gt_pose": gt_pose.tolist(),
                    "pred_pose": pred_pose.tolist(),
                    "eval_result": result,
                }
        # Frames with no target objects still get an (empty) entry so that
        # downstream consumers see every frame.
        for frame_path in dataset.empty_frame:
            final_result[frame_path] = {}
        return final_result


if __name__ == "__main__":
    config_path = os.path.join(PROJECT_ROOT, "configs/server_object_preprocess_config.yaml")
    preprocessor = FoundationPosePreprocessor(config_path)
    preprocessor.run()
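# Usage sketch (illustrative only, not exercised by the pipeline above): the
# dataset also supports per-frame access through the DataLoader imported at
# the top. PyTorch's default collate function cannot batch the trimesh object
# stored under "mesh", so a pass-through collate_fn is assumed here:
#
#   loader = DataLoader(
#       ObjectPoseInferenceDataset(),
#       batch_size=4,
#       collate_fn=lambda samples: samples,  # keep samples as a list of dicts
#   )
#   for samples in loader:
#       for sample in samples:
#           print(sample["frame_path"], sample["gt_pose"].shape)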