# nbv_grasping/datasets/nbv_1/nbv_1_dataset.py
# (exported from a repository web view; last modified 2024-10-09 16:13:22 +00:00)
import os
import re
import json
import pickle
import numpy as np
from PIL import Image
from torchvision import transforms
from configs.config import ConfigManager
from datasets.dataset import AdvancedDataset
from utils.omni_util import OmniUtil
from utils.pcl_util import PclUtil
from utils.pose_util import PoseUtil
class NextOneBestViewDataset(AdvancedDataset):
    """Dataset of (bad view, good view) frame pairs for next-best-view grasping.

    Each item pairs a low-score source frame with a higher-score destination
    frame of the same target object and exposes the source point clouds plus
    the relative camera rotation to the destination frame (see ``getitem``).
    """

    def __init__(self, dataset_config: dict):
        """Read paths/parameters from ``dataset_config`` and the global
        ``ConfigManager``, then build the (bad, good) frame-pair data list.

        Args:
            dataset_config: per-dataset config section; must contain at least
                the keys "data_type", "source" and "gsnet_label".
        """
        super(NextOneBestViewDataset, self).__init__(dataset_config)
        self.data_type = dataset_config["data_type"]
        self.source = dataset_config["source"]
        self.gsnet_label_name = dataset_config["gsnet_label"]
        #self.foundation_pose_label_name = dataset_config["foundation_pose_label"]
        self.data_dir = ConfigManager.get("datasets", "general", "data_dir")
        # Score threshold separating "bad" from "good" views
        # (used in get_grasp_pose_datalist).
        self.score_limit = ConfigManager.get("datasets", "general", "score_limit")
        self.target_pts_num = ConfigManager.get("datasets", "general", "target_pts_num")
        self.scene_pts_num = ConfigManager.get("datasets", "general", "scene_pts_num")
        self.image_size = ConfigManager.get("datasets", "general", "image_size")
        self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")
        self.canonical = ConfigManager.get("datasets", "general", "canonical")
        self.small_batch_overfit = ConfigManager.get("settings", "experiment", "small_batch_overfit")
        # NOTE(review): container_path is assigned but never read in this file;
        # the container set actually comes from load_container_set().
        self.container_path = str(os.path.join(self.data_dir, self.source, "container_set.pickle"))
        self.data_path = str(os.path.join(self.data_dir, self.source, self.data_type))
        self.gsnet_label_path = str(os.path.join(self.data_dir, self.source, self.gsnet_label_name))
        #self.foundation_pose_label_path = str(os.path.join(self.data_dir, self.source, self.foundation_pose_label_name))
        # One sub-directory per scene under <data_dir>/<source>/<data_type>.
        self.scene_list = os.listdir(self.data_path)
        self.task = ConfigManager.get("settings", "experiment", "task")
        self.container_set = self.load_container_set()
        self.data_list = self.get_datalist()
        if self.small_batch_overfit:
            # Shrink the data list to a few samples and repeat them, for
            # small-batch overfitting experiments.
            small_batch_size = ConfigManager.get("settings", "experiment", "small_batch_size")
            small_batch_times = ConfigManager.get("settings", "experiment", "small_batch_times")
            self.data_list = self.data_list[:small_batch_size] * small_batch_times
        # RGB preprocessing; the crop is rounded down to a multiple of 14 —
        # presumably the ViT patch size of the feature extractor (TODO confirm).
        self.transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.CenterCrop(int(self.image_size//14)*14),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.2)
        ])
def __len__(self):
    """Number of (bad_view, good_view) training pairs in this dataset."""
    return len(self.data_list)
def getitem(self, index) -> dict:
    """Assemble one training sample from the pair at *index*.

    The pair is (bad_view, good_view); point clouds come from the bad (source)
    frame, and the label is the relative rotation from the source camera to
    the good (destination) camera.
    """
    bad_view, good_view = self.data_list[index]
    frame_path = bad_view["frame_path"]
    object_name = bad_view["object_name"]
    src_frame = self.load_src_data(frame_path, object_name, canonical=self.canonical)
    dst_frame = self.load_dst_data(good_view["frame_path"])
    # Relative rotation src->dst: R_delta = R_dst^T @ R_src, encoded as 6D.
    rot_src = src_frame["cam_transform"][:3, :3]
    rot_dst = dst_frame["cam_transform"][:3, :3]
    rot_delta_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.dot(rot_dst.T, rot_src))
    # NOTE: rgb / rgb_feat loading is currently disabled (see load_src_data).
    return {
        "src_path": frame_path,
        "target_name": object_name,
        "scene_name": bad_view["scene"],
        "data_type": self.data_type,
        "source": self.source,
        "target_pts": src_frame["target_pts"].astype(np.float32),
        "scene_pts": src_frame["scene_pts"].astype(np.float32),
        "delta_rot_6d": rot_delta_6d.astype(np.float32),
        "src_rot_mat": rot_src.astype(np.float32),
        "dst_rot_mat": rot_dst.astype(np.float32),
        "src_transform": src_frame["cam_transform"].astype(np.float32),
        "dst_transform": dst_frame["cam_transform"].astype(np.float32),
    }
def load_dst_data(self, frame_path):
    """Load only the camera extrinsics of a destination (good-view) frame."""
    return {"cam_transform": OmniUtil.get_transform_mat(frame_path)}
def load_src_data(self, frame_path, target_object_name, canonical=False):
    """Load point clouds and camera pose for a source (bad-view) frame.

    Args:
        frame_path: path prefix of the frame's data files.
        target_object_name: object whose points form the target cloud.
        canonical: when True, re-express both clouds in the canonical frame
            instead of the camera frame.

    Returns:
        dict with "target_pts", "scene_pts" and the 4x4 "cam_transform".
    """
    # Foreground (whole scene) and target-object clouds, resampled to the
    # fixed sizes configured on the dataset.
    foreground_cloud = OmniUtil.get_points(path=frame_path, object_name=OmniUtil.FOREGROUND)
    target_cloud = OmniUtil.get_points(
        path=frame_path, object_name=target_object_name
    )
    foreground_cloud = PclUtil.sample_pcl(foreground_cloud, self.scene_pts_num)
    target_cloud = PclUtil.sample_pcl(target_cloud, self.target_pts_num)
    cam_transform = OmniUtil.get_transform_mat(frame_path)
    if canonical:
        target_cloud = PclUtil.cam2canonical(target_cloud, cam_transform)
        foreground_cloud = PclUtil.cam2canonical(foreground_cloud, cam_transform)
    # NOTE: rgb / rgb_feat loading (via self.rgb_feat_cache and
    # self.transform) used to happen here but is currently disabled.
    return {
        "target_pts": target_cloud,
        "scene_pts": foreground_cloud,
        "cam_transform": cam_transform,
    }
def load_container_set(self):
container_list = ['chair_028', 'chair_029', 'chair_026', 'chair_027', 'table_025', 'table_027', 'table_026', 'table_028', 'sofa_014', 'sofa_013', 'picnic_basket_010', 'picnic_basket_011', 'cabinet_009', 'flower_pot_023', 'flower_pot_022', 'flower_pot_021', 'chair_017', 'chair_020', 'chair_012', 'chair_010', 'chair_018', 'chair_025', 'chair_024', 'chair_011', 'chair_001', 'chair_013', 'chair_004', 'chair_021', 'chair_023', 'chair_006', 'chair_014', 'chair_007', 'chair_003', 'chair_009', 'chair_022', 'chair_015', 'chair_016', 'chair_008', 'chair_005', 'chair_019', 'chair_002', 'table_004', 'table_023', 'table_014', 'table_024', 'table_019', 'table_022', 'table_007', 'table_017', 'table_013', 'table_002', 'table_016', 'table_009', 'table_008', 'table_003', 'table_015', 'table_001', 'table_018', 'table_005', 'table_020', 'table_021', 'sofa_001', 'sofa_005', 'sofa_012', 'sofa_009', 'sofa_006', 'sofa_008', 'sofa_011', 'sofa_004', 'sofa_003', 'sofa_002', 'sofa_007', 'sofa_010', 'picnic_basket_005', 'picnic_basket_004', 'picnic_basket_001', 'picnic_basket_008', 'picnic_basket_002', 'picnic_basket_009', 'picnic_basket_006', 'picnic_basket_003', 'picnic_basket_007', 'cabinet_006', 'cabinet_008', 'cabinet_002', 'cabinet_001', 'cabinet_005', 'cabinet_007', 'flower_pot_013', 'flower_pot_005', 'flower_pot_008', 'flower_pot_001', 'flower_pot_003', 'flower_pot_020', 'flower_pot_006', 'flower_pot_012', 'flower_pot_018', 'flower_pot_007', 'flower_pot_002', 'flower_pot_011', 'flower_pot_010', 'flower_pot_016', 'flower_pot_004', 'flower_pot_014', 'flower_pot_017', 'flower_pot_019']
container_set = set(container_list)
return container_set
def get_ground_object_set(self, scene_name):
    """Unpickle and return the set of objects that fell to the ground in
    *scene_name* (stored as <data_path>/<scene>/fall_objects.pickle)."""
    fall_objects_file = os.path.join(self.data_path, scene_name, "fall_objects.pickle")
    with open(fall_objects_file, "rb") as handle:
        return pickle.load(handle)
def get_datalist(self):
    """Dispatch data-list construction on the configured task.

    Raises:
        NotImplementedError: for the "object_pose" task (disabled).
        ValueError: for any unknown task name.
    """
    if self.task == "grasp_pose":
        return self.get_grasp_pose_datalist()
    if self.task == "object_pose":
        raise NotImplementedError("object_pose task is not supported now.")
    raise ValueError("task must be 'object_pose' or 'grasp_pose'.")
# NOTE(review): disabled legacy implementation for the "object_pose" task,
# kept as a bare string literal so it is never executed; retained for reference
# alongside the commented-out foundation_pose_label config in __init__.
'''
def get_foundation_pose_datalist(self):
data_list = []
for scene in self.scene_list:
scene_path = os.path.join(self.data_path, scene)
gsnet_label_scene_path = os.path.join(self.foundation_pose_label_path, scene)
file_list = os.listdir(scene_path)
scene_frame_list = []
target_object_set = self.get_target_object_set(scene)
unseen_object_set = self.get_target_object_set(scene)
cnt_under = 0
cnt_above = 0
limit = 0.002
for file in file_list:
if file.startswith("camera_params"):
frame_index = re.findall(r"\d+", file)[0]
frame_path = os.path.join(scene_path, frame_index)
score_label_path = os.path.join(
gsnet_label_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(frame_index)
)
with open(score_label_path, "r") as f:
score_label = json.load(f)
for obj_name in score_label.keys():
if obj_name in target_object_set:
scene_frame_list.append(
{
"frame_path": frame_path,
"object_name": obj_name,
"score": score_label[obj_name]["eval_result"]["ADD-S"],
"scene": scene
}
)
if score_label[obj_name]["eval_result"]["ADD-S"] <= limit:
cnt_under += 1
else:
cnt_above += 1
print(f"under {limit}: {cnt_under}, above {limit}: {cnt_above}")
for i in range(len(scene_frame_list)):
for j in range(i, len(scene_frame_list)):
fm_i, fm_j = scene_frame_list[i], scene_frame_list[j]
if fm_i["object_name"] == fm_j["object_name"]:
bad_view, good_view = None, None
if fm_i["score"] <= limit < fm_j["score"]:
good_view, bad_view = fm_i, fm_j
elif fm_i["score"] > limit >= fm_j["score"]:
good_view, bad_view = fm_j, fm_i
if bad_view is None or good_view is None:
continue
data_list.append((bad_view, good_view))
if bad_view["object_name"] in unseen_object_set:
unseen_object_set.remove(bad_view["object_name"])
return data_list
'''
def get_grasp_pose_datalist(self):
    """Build the (bad_view, good_view) training pairs for the grasp-pose task.

    For every scene, each frame's per-object average grasp score is read from
    the GSNet label file. Two frames of the same object form a pair when one
    score is <= ``self.score_limit`` and the other is above it; kept pairs are
    then randomly subsampled with a probability that grows quadratically with
    the score gap. Objects that never appear as the bad view of a kept pair
    get a fallback pairing at the end of each scene.

    Returns:
        list of (bad_view, good_view) tuples; each view is a dict with
        "frame_path", "object_name", "score" and "scene".

    NOTE(review): uses np.random, so the resulting data list is
    non-deterministic across runs unless the numpy seed is fixed.
    """
    data_list = []
    for scene in self.scene_list:
        scene_path = os.path.join(self.data_path, scene)
        gsnet_label_scene_path = os.path.join(self.gsnet_label_path, scene)
        file_list = os.listdir(scene_path)
        scene_frame_list = []
        # Objects lying on the ground and container objects are excluded.
        ground_object_set = self.get_ground_object_set(scene)
        # Objects not yet covered (as bad view) by any kept pair.
        unseen_object_set = set()
        for file in file_list:
            if file.startswith("camera_params"):
                # One "camera_params*" file per frame; digits are the frame index.
                frame_index = re.findall(r"\d+", file)[0]
                frame_path = os.path.join(scene_path, frame_index)
                score_label_path = os.path.join(
                    gsnet_label_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(frame_index)
                )
                with open(score_label_path, "r") as f:
                    score_label = json.load(f)
                for obj_name in score_label["avg_score"].keys():
                    if obj_name not in ground_object_set and obj_name not in self.container_set:
                        scene_frame_list.append(
                            {
                                "frame_path": frame_path,
                                "object_name": obj_name,
                                "score": score_label["avg_score"][obj_name],
                                "scene": scene
                            }
                        )
                        unseen_object_set.add(obj_name)
        # Pair frames of the same object that straddle the score limit.
        # j starts at i, but the i == j pair has equal scores and can never
        # satisfy either strict comparison, so it is always skipped.
        for i in range(len(scene_frame_list)):
            for j in range(i, len(scene_frame_list)):
                fm_i, fm_j = scene_frame_list[i], scene_frame_list[j]
                if fm_i["object_name"] == fm_j["object_name"]:
                    bad_view, good_view = None, None
                    if fm_i["score"] <= self.score_limit < fm_j["score"]:
                        bad_view, good_view = fm_i, fm_j
                    elif fm_i["score"] > self.score_limit >= fm_j["score"]:
                        bad_view, good_view = fm_j, fm_i
                    if bad_view is None or good_view is None:
                        continue
                    # Keep a pair with probability (gap / 0.3)^2: pairs with a
                    # small score gap are mostly dropped, while a gap >= 0.3
                    # always survives (np.random.rand() is in [0, 1)).
                    sample_prob = ((max(0,good_view["score"] - bad_view["score"]))/0.3)**2
                    if np.random.rand() > sample_prob:
                        continue
                    data_list.append((bad_view, good_view))
                    if bad_view["object_name"] in unseen_object_set:
                        unseen_object_set.remove(bad_view["object_name"])
        # Fallback for objects with no kept pair: take their top 10% frames by
        # score as candidate good views and the rest as bad views.
        for obj_name in unseen_object_set:
            views = []
            for frame in scene_frame_list:
                if frame["object_name"] == obj_name:
                    views.append(frame)
            sorted_views = sorted(views, key=lambda x: x["score"], reverse=True)
            total_view_num = len(sorted_views)
            good_view_num = int(total_view_num * 0.1)
            good_views = sorted_views[:good_view_num]
            bad_views = sorted_views[good_view_num:]
            filtered_good_view = []
            # NOTE: filtered_bad_view aliases bad_views, so views demoted
            # below are appended into the same list (intentional here).
            filtered_bad_view = bad_views
            for good_view in good_views:
                # A good view also needs a minimum absolute score of 0.01;
                # otherwise it is demoted to the bad-view pool.
                if good_view["score"] >= 0.01:
                    filtered_good_view.append(good_view)
                else:
                    filtered_bad_view.append(good_view)
            # Cross-product pairing of surviving good views with all bad views.
            for good_view in filtered_good_view:
                for bad_view in filtered_bad_view:
                    data_list.append((bad_view, good_view))
    return data_list