import os
import re
import json
import pickle

import numpy as np
from PIL import Image
from torchvision import transforms

from configs.config import ConfigManager
from datasets.dataset import AdvancedDataset
from utils.omni_util import OmniUtil
from utils.pcl_util import PclUtil
from utils.pose_util import PoseUtil


class NextOneBestViewDataset(AdvancedDataset):
    """Dataset of (bad_view, good_view) frame pairs for next-best-view training.

    Each item pairs a low-score ("bad") source frame with a high-score ("good")
    destination frame of the same object in the same scene, and exposes the
    relative camera rotation between them as a 6D rotation target.
    """

    def __init__(self, dataset_config):
        """Build the pair list from per-scene GSNet score labels.

        Args:
            dataset_config: dict with keys "data_type", "source" and
                "gsnet_label"; everything else is read from ConfigManager.
        """
        super(NextOneBestViewDataset, self).__init__(dataset_config)
        self.data_type = dataset_config["data_type"]
        self.source = dataset_config["source"]
        self.gsnet_label_name = dataset_config["gsnet_label"]
        # self.foundation_pose_label_name = dataset_config["foundation_pose_label"]

        # Global dataset settings.
        self.data_dir = ConfigManager.get("datasets", "general", "data_dir")
        self.score_limit = ConfigManager.get("datasets", "general", "score_limit")
        self.target_pts_num = ConfigManager.get("datasets", "general", "target_pts_num")
        self.scene_pts_num = ConfigManager.get("datasets", "general", "scene_pts_num")
        self.image_size = ConfigManager.get("datasets", "general", "image_size")
        self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")
        self.canonical = ConfigManager.get("datasets", "general", "canonical")
        self.small_batch_overfit = ConfigManager.get("settings", "experiment", "small_batch_overfit")

        # Derived filesystem layout: <data_dir>/<source>/{container_set.pickle, <data_type>, <gsnet_label>}.
        self.container_path = str(os.path.join(self.data_dir, self.source, "container_set.pickle"))
        self.data_path = str(os.path.join(self.data_dir, self.source, self.data_type))
        self.gsnet_label_path = str(os.path.join(self.data_dir, self.source, self.gsnet_label_name))
        # self.foundation_pose_label_path = str(os.path.join(self.data_dir, self.source, self.foundation_pose_label_name))
        self.scene_list = os.listdir(self.data_path)
        self.task = ConfigManager.get("settings", "experiment", "task")

        self.container_set = self.load_container_set()
        self.data_list = self.get_datalist()

        if self.small_batch_overfit:
            # Overfit on a tiny, repeated slice of the data for debugging.
            small_batch_size = ConfigManager.get("settings", "experiment", "small_batch_size")
            small_batch_times = ConfigManager.get("settings", "experiment", "small_batch_times")
            self.data_list = self.data_list[:small_batch_size] * small_batch_times

        # Crop size is floored to a multiple of 14 (presumably a ViT patch
        # size requirement — TODO confirm against the RGB feature extractor).
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.image_size),
                transforms.CenterCrop(int(self.image_size // 14) * 14),
                transforms.ToTensor(),
                transforms.Normalize(mean=0.5, std=0.2),
            ]
        )

    def __len__(self):
        return len(self.data_list)

    def getitem(self, index) -> dict:
        """Return one training sample for the pair at `index`.

        The sample contains source point clouds, the relative rotation from
        the source to the destination camera (as 6D), and both camera poses.
        """
        data_pair = self.data_list[index]
        src_path = data_pair[0]["frame_path"]
        dst_path = data_pair[1]["frame_path"]
        target_name = data_pair[0]["object_name"]
        scene_name = data_pair[0]["scene"]
        src_data = self.load_src_data(src_path, target_name, canonical=self.canonical)
        dst_data = self.load_dst_data(dst_path)

        src_rot = src_data["cam_transform"][:3, :3]
        dst_rot = dst_data["cam_transform"][:3, :3]
        # Rotation taking the source camera frame into the destination frame.
        delta_rot = np.dot(dst_rot.T, src_rot)
        delta_rot_6d = PoseUtil.matrix_to_rotation_6d_numpy(delta_rot)

        item_data = {
            "src_path": src_path,
            "target_name": target_name,
            "scene_name": scene_name,
            "data_type": self.data_type,
            "source": self.source,
            "target_pts": src_data["target_pts"].astype(np.float32),
            "scene_pts": src_data["scene_pts"].astype(np.float32),
            "delta_rot_6d": delta_rot_6d.astype(np.float32),
            "src_rot_mat": src_rot.astype(np.float32),
            "dst_rot_mat": dst_rot.astype(np.float32),
            "src_transform": src_data["cam_transform"].astype(np.float32),
            "dst_transform": dst_data["cam_transform"].astype(np.float32),
        }
        # NOTE(review): disabled RGB branch, kept for reference.
        # if self.rgb_feat_cache:
        #     item_data["rgb_feat"] = src_data["rgb_feat"].astype(np.float32)
        # else:
        #     item_data["rgb"] = src_data["rgb"]
        return item_data

    def load_dst_data(self, frame_path):
        """Load only the camera pose of the destination frame."""
        cam_transform = OmniUtil.get_transform_mat(frame_path)
        frame_data = {"cam_transform": cam_transform}
        return frame_data

    def load_src_data(self, frame_path, target_object_name, canonical=False):
        """Load point clouds and camera pose of the source frame.

        Args:
            frame_path: frame identifier path consumed by OmniUtil helpers.
            target_object_name: object whose points form the target cloud.
            canonical: if True, transform both clouds from camera space into
                the canonical frame via the camera transform.
        """
        # Point clouds: full foreground scene plus the target object, each
        # subsampled to its configured point count.
        scene_pts = OmniUtil.get_points(path=frame_path, object_name=OmniUtil.FOREGROUND)
        target_pts = OmniUtil.get_points(path=frame_path, object_name=target_object_name)
        scene_pts = PclUtil.sample_pcl(scene_pts, self.scene_pts_num)
        target_pts = PclUtil.sample_pcl(target_pts, self.target_pts_num)

        # Camera pose.
        cam_transform = OmniUtil.get_transform_mat(frame_path)
        if canonical:
            target_pts = PclUtil.cam2canonical(target_pts, cam_transform)
            scene_pts = PclUtil.cam2canonical(scene_pts, cam_transform)

        frame_data = {
            "target_pts": target_pts,
            "scene_pts": scene_pts,
            "cam_transform": cam_transform,
        }
        # NOTE(review): disabled RGB branch, kept for reference.
        # if self.rgb_feat_cache:
        #     rgb_feat = OmniUtil.get_rgb_feat(frame_path)
        #     frame_data["rgb_feat"] = rgb_feat
        # else:
        #     rgb = OmniUtil.get_rgb(frame_path)
        #     rgb = Image.fromarray(rgb)
        #     rgb = self.transform(rgb)
        #     frame_data["rgb"] = rgb
        return frame_data

    def load_container_set(self):
        """Return the hard-coded set of container objects to exclude.

        NOTE(review): `self.container_path` points at a pickle with what is
        presumably the same data, but this method intentionally(?) uses the
        inline list instead — confirm which one is authoritative.
        """
        container_list = [
            'chair_028', 'chair_029', 'chair_026', 'chair_027', 'table_025', 'table_027',
            'table_026', 'table_028', 'sofa_014', 'sofa_013', 'picnic_basket_010',
            'picnic_basket_011', 'cabinet_009', 'flower_pot_023', 'flower_pot_022',
            'flower_pot_021', 'chair_017', 'chair_020', 'chair_012', 'chair_010',
            'chair_018', 'chair_025', 'chair_024', 'chair_011', 'chair_001', 'chair_013',
            'chair_004', 'chair_021', 'chair_023', 'chair_006', 'chair_014', 'chair_007',
            'chair_003', 'chair_009', 'chair_022', 'chair_015', 'chair_016', 'chair_008',
            'chair_005', 'chair_019', 'chair_002', 'table_004', 'table_023', 'table_014',
            'table_024', 'table_019', 'table_022', 'table_007', 'table_017', 'table_013',
            'table_002', 'table_016', 'table_009', 'table_008', 'table_003', 'table_015',
            'table_001', 'table_018', 'table_005', 'table_020', 'table_021', 'sofa_001',
            'sofa_005', 'sofa_012', 'sofa_009', 'sofa_006', 'sofa_008', 'sofa_011',
            'sofa_004', 'sofa_003', 'sofa_002', 'sofa_007', 'sofa_010', 'picnic_basket_005',
            'picnic_basket_004', 'picnic_basket_001', 'picnic_basket_008',
            'picnic_basket_002', 'picnic_basket_009', 'picnic_basket_006',
            'picnic_basket_003', 'picnic_basket_007', 'cabinet_006', 'cabinet_008',
            'cabinet_002', 'cabinet_001', 'cabinet_005', 'cabinet_007', 'flower_pot_013',
            'flower_pot_005', 'flower_pot_008', 'flower_pot_001', 'flower_pot_003',
            'flower_pot_020', 'flower_pot_006', 'flower_pot_012', 'flower_pot_018',
            'flower_pot_007', 'flower_pot_002', 'flower_pot_011', 'flower_pot_010',
            'flower_pot_016', 'flower_pot_004', 'flower_pot_014', 'flower_pot_017',
            'flower_pot_019',
        ]
        container_set = set(container_list)
        return container_set

    def get_ground_object_set(self, scene_name):
        """Load the set of fallen/ground objects for `scene_name` (excluded from pairs)."""
        fall_path = os.path.join(self.data_path, scene_name, "fall_objects.pickle")
        with open(fall_path, 'rb') as f:
            fall_objects = pickle.load(f)
        return fall_objects

    def get_datalist(self):
        """Dispatch pair-list construction on the configured task."""
        if self.task == "object_pose":
            # A foundation-pose-based variant existed previously but is not
            # maintained; see repository history for its implementation.
            raise NotImplementedError("object_pose task is not supported now.")
        elif self.task == "grasp_pose":
            return self.get_grasp_pose_datalist()
        else:
            raise ValueError("task must be 'object_pose' or 'grasp_pose'.")

    def get_grasp_pose_datalist(self):
        """Build (bad_view, good_view) pairs from per-frame GSNet avg scores.

        A pair is kept when one view's score is <= score_limit and the
        other's is above it, sampled with probability proportional to the
        squared score gap. Objects never selected as a bad view fall back to
        a top-10%-vs-rest pairing so every object contributes samples.
        """
        data_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            gsnet_label_scene_path = os.path.join(self.gsnet_label_path, scene)
            file_list = os.listdir(scene_path)
            scene_frame_list = []
            ground_object_set = self.get_ground_object_set(scene)
            unseen_object_set = set()

            # Collect one record per (frame, object), skipping ground objects
            # and containers.
            for file in file_list:
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    score_label_path = os.path.join(
                        gsnet_label_scene_path,
                        OmniUtil.SCORE_LABEL_TEMPLATE.format(frame_index),
                    )
                    with open(score_label_path, "r") as f:
                        score_label = json.load(f)
                    for obj_name in score_label["avg_score"].keys():
                        if obj_name not in ground_object_set and obj_name not in self.container_set:
                            scene_frame_list.append(
                                {
                                    "frame_path": frame_path,
                                    "object_name": obj_name,
                                    "score": score_label["avg_score"][obj_name],
                                    "scene": scene,
                                }
                            )
                            unseen_object_set.add(obj_name)

            # Pair frames of the same object that straddle the score limit.
            # Starting at i + 1 skips the self-pair, which can never satisfy
            # either inequality below (the original started at i).
            for i in range(len(scene_frame_list)):
                for j in range(i + 1, len(scene_frame_list)):
                    fm_i, fm_j = scene_frame_list[i], scene_frame_list[j]
                    if fm_i["object_name"] == fm_j["object_name"]:
                        bad_view, good_view = None, None
                        if fm_i["score"] <= self.score_limit < fm_j["score"]:
                            bad_view, good_view = fm_i, fm_j
                        elif fm_i["score"] > self.score_limit >= fm_j["score"]:
                            bad_view, good_view = fm_j, fm_i
                        if bad_view is None or good_view is None:
                            continue
                        # Keep the pair with probability ((gap)/0.3)^2 so
                        # larger score improvements are sampled more often.
                        sample_prob = ((max(0, good_view["score"] - bad_view["score"])) / 0.3) ** 2
                        if np.random.rand() > sample_prob:
                            continue
                        data_list.append((bad_view, good_view))
                        if bad_view["object_name"] in unseen_object_set:
                            unseen_object_set.remove(bad_view["object_name"])

            # Fallback for objects with no sampled pair: top 10% of views by
            # score are "good" (if they score >= 0.01), the rest are "bad".
            for obj_name in unseen_object_set:
                views = []
                for frame in scene_frame_list:
                    if frame["object_name"] == obj_name:
                        views.append(frame)
                sorted_views = sorted(views, key=lambda x: x["score"], reverse=True)
                total_view_num = len(sorted_views)
                good_view_num = int(total_view_num * 0.1)
                good_views = sorted_views[:good_view_num]
                bad_views = sorted_views[good_view_num:]
                filtered_good_view = []
                filtered_bad_view = bad_views
                for good_view in good_views:
                    if good_view["score"] >= 0.01:
                        filtered_good_view.append(good_view)
                    else:
                        # Demote low-scoring "good" views to the bad pool.
                        filtered_bad_view.append(good_view)
                for good_view in filtered_good_view:
                    for bad_view in filtered_bad_view:
                        data_list.append((bad_view, good_view))
        return data_list