# Source file: nbv_grasping/datasets/nbv_1/nbv_1_dataset.py
# Last commit: 2024-10-09 16:13:22 +00:00
# (278 lines, 15 KiB, executable Python file)
import os
import re
import json
import pickle
import numpy as np
from PIL import Image
from torchvision import transforms
from configs.config import ConfigManager
from datasets.dataset import AdvancedDataset
from utils.omni_util import OmniUtil
from utils.pcl_util import PclUtil
from utils.pose_util import PoseUtil
class NextOneBestViewDataset(AdvancedDataset):
    """Dataset of (bad view, good view) frame pairs for next-best-view grasping.

    Each item pairs a low-scoring source frame with a higher-scoring
    destination frame of the same target object, exposing the source-view
    point clouds together with the relative camera rotation between the
    two views (see ``getitem``).
    """

    def __init__(self, dataset_config):
        """Resolve dataset paths and sampling settings, then build the pair list.

        Args:
            dataset_config (dict): must provide "data_type", "source" and
                "gsnet_label"; all remaining settings are read from
                ``ConfigManager``.
        """
        super(NextOneBestViewDataset, self).__init__(dataset_config)
        self.data_type = dataset_config["data_type"]
        self.source = dataset_config["source"]
        self.gsnet_label_name = dataset_config["gsnet_label"]
        #self.foundation_pose_label_name = dataset_config["foundation_pose_label"]
        self.data_dir = ConfigManager.get("datasets", "general", "data_dir")
        # Score threshold separating "bad" from "good" grasp views.
        self.score_limit = ConfigManager.get("datasets", "general", "score_limit")
        self.target_pts_num = ConfigManager.get("datasets", "general", "target_pts_num")
        self.scene_pts_num = ConfigManager.get("datasets", "general", "scene_pts_num")
        self.image_size = ConfigManager.get("datasets", "general", "image_size")
        self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")
        # When True, point clouds are expressed in a canonical frame
        # instead of the camera frame (see load_src_data).
        self.canonical = ConfigManager.get("datasets", "general", "canonical")
        self.small_batch_overfit = ConfigManager.get("settings", "experiment", "small_batch_overfit")
        self.container_path = str(os.path.join(self.data_dir, self.source, "container_set.pickle"))
        self.data_path = str(os.path.join(self.data_dir, self.source, self.data_type))
        self.gsnet_label_path = str(os.path.join(self.data_dir, self.source, self.gsnet_label_name))
        #self.foundation_pose_label_path = str(os.path.join(self.data_dir, self.source, self.foundation_pose_label_name))
        self.scene_list = os.listdir(self.data_path)
        self.task = ConfigManager.get("settings", "experiment", "task")
        self.container_set = self.load_container_set()
        self.data_list = self.get_datalist()
        if self.small_batch_overfit:
            # Overfitting mode: repeat a small slice of the data many times.
            small_batch_size = ConfigManager.get("settings", "experiment", "small_batch_size")
            small_batch_times = ConfigManager.get("settings", "experiment", "small_batch_times")
            self.data_list = self.data_list[:small_batch_size] * small_batch_times
        # RGB preprocessing; the crop size is forced to a multiple of 14
        # (presumably the ViT patch size of the RGB backbone — TODO confirm).
        self.transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.CenterCrop(int(self.image_size//14)*14),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.2)
        ])

    def __len__(self) -> int:
        # Number of (bad view, good view) pairs.
        return len(self.data_list)
def getitem(self, index) -> dict:
    """Assemble one training sample from the index-th (bad, good) view pair.

    Returns a dict holding the source-view point clouds, both camera
    transforms, and the relative rotation between the two views in the
    6D rotation representation. All arrays are cast to float32.
    """
    bad_view, good_view = self.data_list[index]
    src_path = bad_view["frame_path"]
    dst_path = good_view["frame_path"]
    target_name = bad_view["object_name"]
    scene_name = bad_view["scene"]

    src_data = self.load_src_data(src_path, target_name, canonical=self.canonical)
    dst_data = self.load_dst_data(dst_path)

    src_rot = src_data["cam_transform"][:3, :3]
    dst_rot = dst_data["cam_transform"][:3, :3]
    # Relative rotation taking the source camera frame to the destination one.
    delta_rot = dst_rot.T.dot(src_rot)
    delta_rot_6d = PoseUtil.matrix_to_rotation_6d_numpy(delta_rot)

    return {
        "src_path": src_path,
        "target_name": target_name,
        "scene_name": scene_name,
        "data_type": self.data_type,
        "source": self.source,
        "target_pts": src_data["target_pts"].astype(np.float32),
        "scene_pts": src_data["scene_pts"].astype(np.float32),
        "delta_rot_6d": delta_rot_6d.astype(np.float32),
        "src_rot_mat": src_rot.astype(np.float32),
        "dst_rot_mat": dst_rot.astype(np.float32),
        "src_transform": src_data["cam_transform"].astype(np.float32),
        "dst_transform": dst_data["cam_transform"].astype(np.float32),
    }
def load_dst_data(self, frame_path):
    """Load the destination view; only its camera extrinsics are needed."""
    return {"cam_transform": OmniUtil.get_transform_mat(frame_path)}
def load_src_data(self, frame_path, target_object_name, canonical=False):
    """Load the source view: sampled point clouds plus the camera pose.

    Args:
        frame_path: frame to read (passed through to OmniUtil helpers).
        target_object_name: name of the grasp target within the frame.
        canonical: when True, express both point clouds in the canonical
            frame via ``PclUtil.cam2canonical`` instead of the camera frame.

    Returns:
        dict with "target_pts", "scene_pts" and "cam_transform".
    """
    cam_transform = OmniUtil.get_transform_mat(frame_path)

    # Raw clouds, then downsample to the configured point counts.
    foreground = OmniUtil.get_points(path=frame_path, object_name=OmniUtil.FOREGROUND)
    target = OmniUtil.get_points(path=frame_path, object_name=target_object_name)
    scene_pts = PclUtil.sample_pcl(foreground, self.scene_pts_num)
    target_pts = PclUtil.sample_pcl(target, self.target_pts_num)

    if canonical:
        target_pts = PclUtil.cam2canonical(target_pts, cam_transform)
        scene_pts = PclUtil.cam2canonical(scene_pts, cam_transform)

    return {
        "target_pts": target_pts,
        "scene_pts": scene_pts,
        "cam_transform": cam_transform,
    }
def load_container_set(self):
    """Return the fixed set of container/furniture names never used as targets."""
    container_names = (
        'chair_028', 'chair_029', 'chair_026', 'chair_027', 'table_025',
        'table_027', 'table_026', 'table_028', 'sofa_014', 'sofa_013',
        'picnic_basket_010', 'picnic_basket_011', 'cabinet_009',
        'flower_pot_023', 'flower_pot_022', 'flower_pot_021', 'chair_017',
        'chair_020', 'chair_012', 'chair_010', 'chair_018', 'chair_025',
        'chair_024', 'chair_011', 'chair_001', 'chair_013', 'chair_004',
        'chair_021', 'chair_023', 'chair_006', 'chair_014', 'chair_007',
        'chair_003', 'chair_009', 'chair_022', 'chair_015', 'chair_016',
        'chair_008', 'chair_005', 'chair_019', 'chair_002', 'table_004',
        'table_023', 'table_014', 'table_024', 'table_019', 'table_022',
        'table_007', 'table_017', 'table_013', 'table_002', 'table_016',
        'table_009', 'table_008', 'table_003', 'table_015', 'table_001',
        'table_018', 'table_005', 'table_020', 'table_021', 'sofa_001',
        'sofa_005', 'sofa_012', 'sofa_009', 'sofa_006', 'sofa_008',
        'sofa_011', 'sofa_004', 'sofa_003', 'sofa_002', 'sofa_007',
        'sofa_010', 'picnic_basket_005', 'picnic_basket_004',
        'picnic_basket_001', 'picnic_basket_008', 'picnic_basket_002',
        'picnic_basket_009', 'picnic_basket_006', 'picnic_basket_003',
        'picnic_basket_007', 'cabinet_006', 'cabinet_008', 'cabinet_002',
        'cabinet_001', 'cabinet_005', 'cabinet_007', 'flower_pot_013',
        'flower_pot_005', 'flower_pot_008', 'flower_pot_001',
        'flower_pot_003', 'flower_pot_020', 'flower_pot_006',
        'flower_pot_012', 'flower_pot_018', 'flower_pot_007',
        'flower_pot_002', 'flower_pot_011', 'flower_pot_010',
        'flower_pot_016', 'flower_pot_004', 'flower_pot_014',
        'flower_pot_017', 'flower_pot_019',
    )
    return set(container_names)
def get_ground_object_set(self, scene_name):
    """Load the recorded set of objects that fell to the ground in a scene.

    Reads ``<data_path>/<scene_name>/fall_objects.pickle`` and returns the
    unpickled collection.
    """
    fall_record = os.path.join(self.data_path, scene_name, "fall_objects.pickle")
    with open(fall_record, "rb") as handle:
        return pickle.load(handle)
def get_datalist(self):
    """Dispatch to the task-specific (bad_view, good_view) pair-list builder.

    Returns:
        list: (bad_view, good_view) frame-info pairs for the configured task.

    Raises:
        NotImplementedError: for the "object_pose" task, which is currently
            disabled (the old foundation-pose pipeline was removed; see
            version-control history for ``get_foundation_pose_datalist``).
        ValueError: for any unrecognised task name.
    """
    # NOTE: the original file kept the retired ~50-line
    # get_foundation_pose_datalist implementation here as a bare
    # triple-quoted string literal — dead code evaluated and discarded at
    # class creation. It has been removed; recover it from VCS if needed.
    if self.task == "object_pose":
        raise NotImplementedError("object_pose task is not supported now.")
    elif self.task == "grasp_pose":
        return self.get_grasp_pose_datalist()
    else:
        raise ValueError("task must be 'object_pose' or 'grasp_pose'.")
def get_grasp_pose_datalist(self):
    """Build (bad_view, good_view) training pairs from per-frame grasp scores.

    For every scene, each frame/object combination with a GSNet average
    grasp score becomes a candidate view; objects lying on the ground and
    container objects are skipped. Pairs are formed in two passes:

    1. Threshold pass — two views of the same object whose scores fall on
       opposite sides of ``self.score_limit`` form a (bad, good) pair,
       kept with probability ((score gap)/0.3)**2, favouring large gaps.
    2. Fallback pass — objects that never served as the "bad" view in
       pass 1 are paired by ranking their views: the top 10% (if their
       score is at least 0.01) are good views, all others bad views.

    Returns:
        list: (bad_view, good_view) tuples of frame-info dicts, each with
        keys "frame_path", "object_name", "score", "scene".
    """
    data_list = []
    for scene in self.scene_list:
        scene_path = os.path.join(self.data_path, scene)
        gsnet_label_scene_path = os.path.join(self.gsnet_label_path, scene)
        file_list = os.listdir(scene_path)
        scene_frame_list = []
        ground_object_set = self.get_ground_object_set(scene)
        unseen_object_set = set()
        # Collect one candidate entry per (frame, graspable object).
        for file in file_list:
            if file.startswith("camera_params"):
                frame_index = re.findall(r"\d+", file)[0]
                frame_path = os.path.join(scene_path, frame_index)
                score_label_path = os.path.join(
                    gsnet_label_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(frame_index)
                )
                with open(score_label_path, "r") as f:
                    score_label = json.load(f)
                for obj_name in score_label["avg_score"].keys():
                    # Skip fallen objects and container furniture.
                    if obj_name not in ground_object_set and obj_name not in self.container_set:
                        scene_frame_list.append(
                            {
                                "frame_path": frame_path,
                                "object_name": obj_name,
                                "score": score_label["avg_score"][obj_name],
                                "scene": scene
                            }
                        )
                        unseen_object_set.add(obj_name)
        # Pass 1: pairs whose scores straddle the score_limit threshold.
        for i in range(len(scene_frame_list)):
            for j in range(i, len(scene_frame_list)):
                fm_i, fm_j = scene_frame_list[i], scene_frame_list[j]
                if fm_i["object_name"] == fm_j["object_name"]:
                    bad_view, good_view = None, None
                    if fm_i["score"] <= self.score_limit < fm_j["score"]:
                        bad_view, good_view = fm_i, fm_j
                    elif fm_i["score"] > self.score_limit >= fm_j["score"]:
                        bad_view, good_view = fm_j, fm_i
                    if bad_view is None or good_view is None:
                        continue
                    # Keep with probability proportional to the squared
                    # score gap; gaps of 0.3 or more are always kept.
                    sample_prob = ((max(0,good_view["score"] - bad_view["score"]))/0.3)**2
                    if np.random.rand() > sample_prob:
                        continue
                    data_list.append((bad_view, good_view))
                    # Object has now appeared as a "bad" view at least once.
                    if bad_view["object_name"] in unseen_object_set:
                        unseen_object_set.remove(bad_view["object_name"])
        # Pass 2: fallback pairing for objects never used as a bad view.
        for obj_name in unseen_object_set:
            views = []
            for frame in scene_frame_list:
                if frame["object_name"] == obj_name:
                    views.append(frame)
            sorted_views = sorted(views, key=lambda x: x["score"], reverse=True)
            total_view_num = len(sorted_views)
            good_view_num = int(total_view_num * 0.1)
            good_views = sorted_views[:good_view_num]
            bad_views = sorted_views[good_view_num:]
            filtered_good_view = []
            filtered_bad_view = bad_views
            for good_view in good_views:
                # Demote "good" views whose absolute score is still too low.
                if good_view["score"] >= 0.01:
                    filtered_good_view.append(good_view)
                else:
                    filtered_bad_view.append(good_view)
            # Cross-product: every good view paired with every bad view.
            for good_view in filtered_good_view:
                for bad_view in filtered_bad_view:
                    data_list.append((bad_view, good_view))
    return data_list