nbv_grasping/runners/preprocessors/object_pose/FoundationPose_preprocessor.py
2024-10-09 16:13:22 +00:00

185 lines
7.2 KiB
Python
Executable File

import os
import re
import sys
import numpy as np
import torch
import trimesh
from torch.utils.data import DataLoader
# This file lives at <root>/runners/preprocessors/object_pose/<file>, so
# stripping four trailing path components from its absolute path yields the
# project root, which is then made importable.
path = os.path.abspath(__file__)
_levels_to_strip = 4
while _levels_to_strip:
    path = os.path.dirname(path)
    _levels_to_strip -= 1
PROJECT_ROOT = path
# Appended, so entries already on sys.path keep precedence over the project root.
sys.path.append(PROJECT_ROOT)
from utils.omni_util import OmniUtil
from utils.view_util import ViewUtil
from runners.preprocessors.object_pose.abstract_object_pose_preprocessor import ObjectPosePreprocessor
from configs.config import ConfigManager
from torch.utils.data import Dataset
class ObjectPoseInferenceDataset(Dataset):
    """Dataset of (frame, target object) samples for object-pose inference.

    On-disk layout read by this class:

    * ``<data_dir>/<source>/<data_type>/<scene>/`` -- one directory per scene.
      Every file whose name starts with ``camera_params`` marks a frame; the
      first run of digits in that file name is the frame index.
    * ``<data_dir>/<source>/objects/<object_name>/Scan/Simp.obj`` -- object
      meshes, loaded once at construction and scaled by ``MESH_SCALE``.

    A frame is addressed as ``<scene_dir>/<frame_index>``, and all per-frame
    data is read through ``OmniUtil`` using that path.
    """

    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"
    # Uniform scale applied to every loaded mesh (presumably millimetres ->
    # metres; TODO confirm against the scan export units).
    MESH_SCALE = (0.001, 0.001, 0.001)

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
    ):
        self.data_dir = data_dir
        # Frame paths for which no target objects were found; filled by get_datalist().
        self.empty_frame = set()
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        # Sorted so sample order does not depend on filesystem listing order.
        self.scene_list = sorted(os.listdir(self.data_path))
        self.data_list = self.get_datalist()
        self.object_data_list = self.get_object_datalist()
        self.object_name_list = list(self.object_data_list.keys())
        self.mesh_dir_path = os.path.join(self.data_dir, source, "objects")
        self.meshes = {}
        self.load_all_meshes()

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path, target = self.data_list[index]
        return self.load_frame_data(frame_path=frame_path, object_name=target)

    def load_all_meshes(self):
        """Load every ``<object>/Scan/Simp.obj`` under ``mesh_dir_path`` into ``self.meshes``."""
        for object_name in os.listdir(self.mesh_dir_path):
            object_dir = os.path.join(self.mesh_dir_path, object_name)
            # Skip stray files (e.g. OS metadata) that are not object folders.
            if not os.path.isdir(object_dir):
                continue
            mesh = trimesh.load(os.path.join(object_dir, "Scan", "Simp.obj"))
            mesh.apply_scale(self.MESH_SCALE)
            self.meshes[object_name] = mesh

    def get_datalist(self):
        """Return all ``(frame_path, target_object_name)`` pairs across every scene.

        Frames for which ``OmniUtil.get_object_list`` reports no objects
        contribute no pairs and are recorded in ``self.empty_frame`` instead.
        """
        # One list for all scenes: it must not be reset per scene, otherwise
        # only a single scene's frames would be returned.
        data_list = []
        for frame_path in self._iter_frame_paths():
            target_list = OmniUtil.get_object_list(frame_path)
            if len(target_list) == 0:
                self.empty_frame.add(frame_path)
                continue
            data_list.extend((frame_path, target) for target in target_list)
        return data_list

    def _iter_frame_paths(self):
        """Yield ``<scene_dir>/<frame_index>`` for every frame of every scene.

        A frame is identified by a file whose name starts with ``camera_params``;
        its index is the first run of digits in that file name.
        """
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            # Ignore non-directory entries sitting next to the scene folders.
            if not os.path.isdir(scene_path):
                continue
            for file_name in sorted(os.listdir(scene_path)):
                if not file_name.startswith("camera_params"):
                    continue
                frame_index = re.findall(r"\d+", file_name)[0]
                yield os.path.join(scene_path, frame_index)

    def get_object_datalist(self):
        """Group ``self.data_list`` by object: ``{object_name: [frame_path, ...]}``."""
        object_datalist = {}
        for frame_path, target in self.data_list:
            object_datalist.setdefault(target, []).append(frame_path)
        return object_datalist

    def get_object_data_batch(self, object_name):
        """Load every frame that contains ``object_name`` and stack it into one batch.

        Returns a dict with:
            frame_path_list: list of frame paths, in batch order.
            rgb_batch: uint8 array of the frames' RGB images.
            depth_batch, seg_batch, gt_pose_batch: per-frame arrays stacked
                along a new leading axis.
            K, mesh: taken from the last loaded frame (presumably identical
                for all frames of a dataset -- TODO confirm).
        """
        frame_path_list = self.object_data_list[object_name]
        frames = [self.load_frame_data(frame_path, object_name) for frame_path in frame_path_list]
        return {
            "frame_path_list": list(frame_path_list),
            "rgb_batch": np.asarray([frame["rgb"] for frame in frames], dtype=np.uint8),
            "depth_batch": np.asarray([frame["depth"] for frame in frames]),
            "seg_batch": np.asarray([frame["seg"] for frame in frames]),
            "gt_pose_batch": np.asarray([frame["gt_pose"] for frame in frames]),
            "K": frames[-1]["K"] if frames else None,
            "mesh": frames[-1]["mesh"] if frames else None,
        }

    def load_frame_data(self, frame_path, object_name):
        """Load one frame's inputs and ground-truth pose for ``object_name``.

        ``rgb``, ``depth``, ``K`` and ``gt_pose`` are cast to float32; ``seg`` is
        returned as produced by ``OmniUtil.get_single_seg``. ``mesh`` is the
        preloaded mesh for ``object_name`` (KeyError if it was not loaded).
        """
        rgb = OmniUtil.get_rgb(frame_path)
        depth = OmniUtil.get_depth(frame_path)
        seg = OmniUtil.get_single_seg(frame_path, object_name)
        K = OmniUtil.get_intrinsic_matrix(frame_path)
        gt_obj_pose = OmniUtil.get_o2c_pose(frame_path, object_name)
        return {
            "frame_path": frame_path,
            "rgb": rgb.astype(np.float32),
            "depth": depth.astype(np.float32),
            "seg": seg,
            "K": K.astype(np.float32),
            "object_name": object_name,
            "mesh": self.meshes[object_name],
            "gt_pose": gt_obj_pose.astype(np.float32),
        }
class FoundationPosePreprocessor(ObjectPosePreprocessor):
    """Runs object-pose inference over every configured dataset and saves the results.

    Pose estimation itself is delegated to ``ViewUtil.get_object_pose_batch``,
    called with ``self.web_server_config["port"]``. ``dataset_list_config``,
    ``web_server_config`` and ``save_processed_data`` presumably come from
    ``ObjectPosePreprocessor`` (TODO confirm).
    """

    def __init__(self, config_path):
        super().__init__(config_path)

    def run(self):
        """Build a dataset for each entry of ``self.dataset_list_config``, predict, and save."""
        for dataset_config in self.dataset_list_config:
            dataset = ObjectPoseInferenceDataset(
                source=dataset_config["source"],
                data_type=dataset_config["data_type"],
                data_dir=dataset_config["data_dir"],
            )
            result = self.prediction(dataset)
            self.save_processed_data(result, dataset_config)

    def prediction(self, dataset):
        """Estimate poses for every object in ``dataset``, one batch per object.

        Returns ``{frame_path: {object_name: {"gt_pose": list, "pred_pose": list,
        "eval_result": ...}}}``. Frames listed in ``dataset.empty_frame`` map to ``{}``.
        """
        final_result = {}
        object_names = dataset.object_name_list
        for cnt, object_name in enumerate(object_names, start=1):
            print(f"Processing object: {object_name} ({cnt}/{len(object_names)})")
            object_data_batch = dataset.get_object_data_batch(object_name)
            frame_path_list = object_data_batch["frame_path_list"]
            print(f"batch size of object {object_name}: {len(frame_path_list)}")
            pose_batch, result_batch = ViewUtil.get_object_pose_batch(
                object_data_batch["K"],
                object_data_batch["mesh"],
                object_data_batch["rgb_batch"],
                object_data_batch["depth_batch"],
                object_data_batch["seg_batch"],
                object_data_batch["gt_pose_batch"],
                self.web_server_config["port"],
            )
            # zip() stops at the shortest input, so count what was actually
            # recorded to surface any frames get_object_pose_batch returned
            # fewer results for, instead of dropping them silently.
            recorded = 0
            for frame_path, pred_pose, gt_pose, result in zip(
                frame_path_list,
                pose_batch,
                object_data_batch["gt_pose_batch"],
                result_batch,
            ):
                final_result.setdefault(frame_path, {})[object_name] = {
                    "gt_pose": gt_pose.tolist(),
                    "pred_pose": pred_pose.tolist(),
                    "eval_result": result,
                }
                recorded += 1
            if recorded != len(frame_path_list):
                print(
                    f"Warning: only {recorded}/{len(frame_path_list)} poses recorded "
                    f"for object {object_name}; the remaining frames are missing from the result"
                )
        for frame_path in dataset.empty_frame:
            final_result[frame_path] = {}
        return final_result
if __name__ == "__main__":
    # Script entry point: run the preprocessor with the project's server-side
    # object-preprocessing config.
    FoundationPosePreprocessor(
        os.path.join(PROJECT_ROOT, "configs/server_object_preprocess_config.yaml")
    ).run()