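"""Object-pose preprocessing: for every frame/object pair of a rendered
dataset, query pose predictions (batched per object through ViewUtil's
web-service interface) and save the per-frame results.

Path defaults below (e.g. data_dir) are machine-specific.
"""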
import os
import re
import sys

import numpy as np
import trimesh
from torch.utils.data import Dataset

# Make the project root importable: this file sits four directory levels
# below the repository root.
path = os.path.abspath(__file__)
for _ in range(4):
    path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)

from utils.omni_util import OmniUtil
from utils.view_util import ViewUtil
from runners.preprocessors.object_pose.abstract_object_pose_preprocessor import ObjectPosePreprocessor
from configs.config import ConfigManager


class ObjectPoseInferenceDataset(Dataset):
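    """Dataset of (frame, object) pairs for object-pose inference.

    Each item bundles one rendered frame (RGB, depth, per-object
    segmentation mask, camera intrinsics) with one target object, its
    simplified scan mesh, and the ground-truth object-to-camera pose.
    """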

    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
    ):
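        """Index every (frame, object) pair under data_dir/source/data_type
        and preload one simplified mesh per object."""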
        self.data_dir = data_dir
        self.empty_frame = set()
        self.data_path = os.path.join(self.data_dir, source, data_type)
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()

        self.object_data_list = self.get_object_datalist()
        self.object_name_list = list(self.object_data_list.keys())
        self.mesh_dir_path = os.path.join(self.data_dir, source, "objects")

        self.meshes = {}
        self.load_all_meshes()

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path, target = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path, object_name=target)
        return frame_data

    def load_all_meshes(self):
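        """Load and cache objects/<name>/Scan/Simp.obj for every object."""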
        object_name_list = os.listdir(self.mesh_dir_path)
        for object_name in object_name_list:
            mesh_path = os.path.join(self.mesh_dir_path, object_name, "Scan", "Simp.obj")
            mesh = trimesh.load(mesh_path)
            # Scale the raw scan (presumably modeled in millimeters) to meters.
            object_model_scale = [0.001, 0.001, 0.001]
            mesh.apply_scale(object_model_scale)
            self.meshes[object_name] = mesh

    def get_datalist(self):
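        """Scan every scene directory and return a list of
        (frame_path, target_object) pairs; frames without any target
        object are collected in self.empty_frame instead."""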
        # Initialize once, outside the scene loop, so that frames from
        # every scene are kept (not just the last one).
        scene_frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            file_list = os.listdir(scene_path)
            for file in file_list:
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    target_list = OmniUtil.get_object_list(frame_path)
                    for target in target_list:
                        scene_frame_list.append((frame_path, target))
                    if len(target_list) == 0:
                        self.empty_frame.add(frame_path)
        return scene_frame_list

    def get_object_datalist(self):
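        """Invert self.data_list into {object_name: [frame_path, ...]}."""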
        object_datalist = {}
        for data_item in self.data_list:
            frame_path, target = data_item
            if target not in object_datalist:
                object_datalist[target] = []
            object_datalist[target].append(frame_path)
        return object_datalist

    def get_object_data_batch(self, object_name):
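        """Stack all frames that contain object_name into numpy batches.

        "K" and "mesh" are overwritten on every iteration, so the values
        from the last frame win; this assumes the intrinsics are identical
        across the frames of one source.
        """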
        object_data_list = self.object_data_list[object_name]
        batch_data = {
            "frame_path_list": [],
            "rgb_batch": [],
            "depth_batch": [],
            "seg_batch": [],
            "gt_pose_batch": [],
            "K": None,
            "mesh": None,
        }
        for frame_path in object_data_list:
            frame_data = self.load_frame_data(frame_path, object_name)
            batch_data["frame_path_list"].append(frame_path)
            batch_data["rgb_batch"].append(frame_data["rgb"])
            batch_data["depth_batch"].append(frame_data["depth"])
            batch_data["seg_batch"].append(frame_data["seg"])
            batch_data["gt_pose_batch"].append(frame_data["gt_pose"])
            batch_data["K"] = frame_data["K"]
            batch_data["mesh"] = frame_data["mesh"]

        batch_data["rgb_batch"] = np.asarray(batch_data["rgb_batch"], dtype=np.uint8)
        batch_data["depth_batch"] = np.asarray(batch_data["depth_batch"])
        batch_data["seg_batch"] = np.asarray(batch_data["seg_batch"])
        batch_data["gt_pose_batch"] = np.asarray(batch_data["gt_pose_batch"])
        return batch_data

    def load_frame_data(self, frame_path, object_name):
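        """Load RGB, depth, per-object segmentation, intrinsics, the cached
        mesh, and the ground-truth object-to-camera pose for one frame."""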
        rgb = OmniUtil.get_rgb(frame_path)
        depth = OmniUtil.get_depth(frame_path)
        seg = OmniUtil.get_single_seg(frame_path, object_name)
        K = OmniUtil.get_intrinsic_matrix(frame_path)
        gt_obj_pose = OmniUtil.get_o2c_pose(frame_path, object_name)
        ret_dict = {
            "frame_path": frame_path,
            "rgb": rgb.astype(np.float32),
            "depth": depth.astype(np.float32),
            "seg": seg,
            "K": K.astype(np.float32),
            "object_name": object_name,
            "mesh": self.meshes[object_name],
            "gt_pose": gt_obj_pose.astype(np.float32),
        }
        return ret_dict


class FoundationPosePreprocessor(ObjectPosePreprocessor):
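    """Preprocessor that obtains object poses from a FoundationPose
    inference service (reached through ViewUtil on the configured port)
    and stores the per-frame results."""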

    def __init__(self, config_path):
        super().__init__(config_path)

    def run(self):
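        """Build one inference dataset per configured source/data_type and
        save its predictions."""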
        for dataset_config in self.dataset_list_config:
            dataset = ObjectPoseInferenceDataset(
                source=dataset_config["source"],
                data_type=dataset_config["data_type"],
                data_dir=dataset_config["data_dir"],
            )
            result = self.prediction(dataset)
            self.save_processed_data(result, dataset_config)

    def prediction(self, dataset):
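        """Query pose predictions for all frames, batched per object.

        Returns {frame_path: {object_name: {"gt_pose", "pred_pose",
        "eval_result"}}}; frames without targets map to empty dicts.
        """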
        final_result = {}
        for cnt, object_name in enumerate(dataset.object_name_list, start=1):
            print(f"Processing object: {object_name} ({cnt}/{len(dataset.object_name_list)})")
            object_data_batch = dataset.get_object_data_batch(object_name)
            print(f"batch size of object {object_name}: {len(object_data_batch['frame_path_list'])}")
            pose_batch, result_batch = ViewUtil.get_object_pose_batch(
                object_data_batch["K"],
                object_data_batch["mesh"],
                object_data_batch["rgb_batch"],
                object_data_batch["depth_batch"],
                object_data_batch["seg_batch"],
                object_data_batch["gt_pose_batch"],
                self.web_server_config["port"],
            )
            for frame_path, pred_pose, gt_pose, result in zip(
                object_data_batch["frame_path_list"],
                pose_batch,
                object_data_batch["gt_pose_batch"],
                result_batch,
            ):
                if frame_path not in final_result:
                    final_result[frame_path] = {}
                final_result[frame_path][object_name] = {
                    "gt_pose": gt_pose.tolist(),
                    "pred_pose": pred_pose.tolist(),
                    "eval_result": result,
                }
        # Frames with no target objects still get an (empty) entry.
        for frame_path in dataset.empty_frame:
            final_result[frame_path] = {}
        return final_result
if __name__ == "__main__":
|
||
|
config_path = os.path.join(PROJECT_ROOT, "configs/server_object_preprocess_config.yaml")
|
||
|
preprocessor = FoundationPosePreprocessor(config_path)
|
||
|
preprocessor.run()
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|