278 lines
15 KiB
Python
Executable File
278 lines
15 KiB
Python
Executable File
import os
|
|
import re
|
|
import json
|
|
import pickle
|
|
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
|
|
from configs.config import ConfigManager
|
|
from datasets.dataset import AdvancedDataset
|
|
from utils.omni_util import OmniUtil
|
|
from utils.pcl_util import PclUtil
|
|
from utils.pose_util import PoseUtil
|
|
|
|
|
|
class NextOneBestViewDataset(AdvancedDataset):
    """Dataset of (source view, destination view) frame pairs for next-best-view
    training: each sample pairs a low-score ("bad") view with a high-score
    ("good") view of the same object in the same scene."""

    def __init__(self, dataset_config):
        """Build the pair list from config and the on-disk scene layout.

        Args:
            dataset_config (dict): must contain "data_type", "source" and
                "gsnet_label"; also forwarded to AdvancedDataset.
        """
        super(NextOneBestViewDataset, self).__init__(dataset_config)
        self.data_type = dataset_config["data_type"]
        self.source = dataset_config["source"]
        self.gsnet_label_name = dataset_config["gsnet_label"]
        #self.foundation_pose_label_name = dataset_config["foundation_pose_label"]
        # Global settings (paths, sampling sizes, flags) come from ConfigManager.
        self.data_dir = ConfigManager.get("datasets", "general", "data_dir")
        self.score_limit = ConfigManager.get("datasets", "general", "score_limit")
        self.target_pts_num = ConfigManager.get("datasets", "general", "target_pts_num")
        self.scene_pts_num = ConfigManager.get("datasets", "general", "scene_pts_num")
        self.image_size = ConfigManager.get("datasets", "general", "image_size")
        self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")
        self.canonical = ConfigManager.get("datasets", "general", "canonical")
        self.small_batch_overfit = ConfigManager.get("settings", "experiment", "small_batch_overfit")
        # NOTE(review): container_path is computed but never read in this file —
        # load_container_set() uses a hard-coded list instead; confirm intent.
        self.container_path = str(os.path.join(self.data_dir, self.source, "container_set.pickle"))
        self.data_path = str(os.path.join(self.data_dir, self.source, self.data_type))
        self.gsnet_label_path = str(os.path.join(self.data_dir, self.source, self.gsnet_label_name))
        #self.foundation_pose_label_path = str(os.path.join(self.data_dir, self.source, self.foundation_pose_label_name))
        # One subdirectory per scene under data_path.
        self.scene_list = os.listdir(self.data_path)
        self.task = ConfigManager.get("settings", "experiment", "task")
        # Order matters: get_datalist() filters objects against container_set.
        self.container_set = self.load_container_set()
        self.data_list = self.get_datalist()

        if self.small_batch_overfit:
            # Overfit mode: repeat a small slice of the data many times.
            small_batch_size = ConfigManager.get("settings", "experiment", "small_batch_size")
            small_batch_times = ConfigManager.get("settings", "experiment", "small_batch_times")
            self.data_list = self.data_list[:small_batch_size] * small_batch_times

        # RGB preprocessing pipeline (only used when RGB loading is enabled).
        # Crop to a multiple of 14 — presumably a ViT patch size of 14; confirm.
        self.transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.CenterCrop(int(self.image_size//14)*14),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.2)
        ])
|
|
|
|
|
|
def __len__(self):
|
|
return len(self.data_list)
|
|
|
|
def getitem(self, index) -> dict:
|
|
data_pair = self.data_list[index]
|
|
src_path = data_pair[0]["frame_path"]
|
|
dst_path = data_pair[1]["frame_path"]
|
|
target_name = data_pair[0]["object_name"]
|
|
scene_name = data_pair[0]["scene"]
|
|
src_data = self.load_src_data(src_path, target_name, canonical=self.canonical)
|
|
dst_data = self.load_dst_data(dst_path)
|
|
src_rot = src_data["cam_transform"][:3,:3]
|
|
dst_rot = dst_data["cam_transform"][:3,:3]
|
|
delta_rot = np.dot(dst_rot.T, src_rot)
|
|
delta_rot_6d = PoseUtil.matrix_to_rotation_6d_numpy(delta_rot)
|
|
|
|
item_data = {
|
|
"src_path": src_path,
|
|
"target_name": target_name,
|
|
"scene_name": scene_name,
|
|
"data_type": self.data_type,
|
|
"source": self.source,
|
|
"target_pts": src_data["target_pts"].astype(np.float32),
|
|
"scene_pts": src_data["scene_pts"].astype(np.float32),
|
|
"delta_rot_6d": delta_rot_6d.astype(np.float32),
|
|
"src_rot_mat": src_rot.astype(np.float32),
|
|
"dst_rot_mat": dst_rot.astype(np.float32),
|
|
"src_transform": src_data["cam_transform"].astype(np.float32),
|
|
"dst_transform": dst_data["cam_transform"].astype(np.float32),
|
|
}
|
|
# if self.rgb_feat_cache:
|
|
# item_data["rgb_feat"] = src_data["rgb_feat"].astype(np.float32)
|
|
# else:
|
|
# item_data["rgb"] = src_data["rgb"]
|
|
return item_data
|
|
|
|
def load_dst_data(self, frame_path):
|
|
""" camera params """
|
|
cam_transform = OmniUtil.get_transform_mat(frame_path)
|
|
frame_data = {'cam_transform': cam_transform}
|
|
return frame_data
|
|
|
|
def load_src_data(self, frame_path, target_object_name, canonical = False):
|
|
""" pts """
|
|
scene_pts = OmniUtil.get_points(path=frame_path, object_name=OmniUtil.FOREGROUND)
|
|
target_pts = OmniUtil.get_points(
|
|
path=frame_path, object_name=target_object_name
|
|
)
|
|
scene_pts = PclUtil.sample_pcl(scene_pts, self.scene_pts_num)
|
|
target_pts = PclUtil.sample_pcl(target_pts, self.target_pts_num)
|
|
|
|
|
|
""" camera params """
|
|
cam_transform = OmniUtil.get_transform_mat(frame_path)
|
|
if canonical:
|
|
target_pts = PclUtil.cam2canonical(target_pts, cam_transform)
|
|
scene_pts = PclUtil.cam2canonical(scene_pts, cam_transform)
|
|
|
|
frame_data = {
|
|
"target_pts": target_pts,
|
|
"scene_pts": scene_pts,
|
|
"cam_transform": cam_transform,
|
|
}
|
|
|
|
""" rgb """
|
|
# if self.rgb_feat_cache:
|
|
# rgb_feat = OmniUtil.get_rgb_feat(frame_path)
|
|
# frame_data["rgb_feat"] = rgb_feat
|
|
# else:
|
|
# rgb = OmniUtil.get_rgb(frame_path)
|
|
# rgb = Image.fromarray(rgb)
|
|
# rgb = self.transform(rgb)
|
|
# frame_data["rgb"] = rgb
|
|
|
|
return frame_data
|
|
|
|
def load_container_set(self):
|
|
container_list = ['chair_028', 'chair_029', 'chair_026', 'chair_027', 'table_025', 'table_027', 'table_026', 'table_028', 'sofa_014', 'sofa_013', 'picnic_basket_010', 'picnic_basket_011', 'cabinet_009', 'flower_pot_023', 'flower_pot_022', 'flower_pot_021', 'chair_017', 'chair_020', 'chair_012', 'chair_010', 'chair_018', 'chair_025', 'chair_024', 'chair_011', 'chair_001', 'chair_013', 'chair_004', 'chair_021', 'chair_023', 'chair_006', 'chair_014', 'chair_007', 'chair_003', 'chair_009', 'chair_022', 'chair_015', 'chair_016', 'chair_008', 'chair_005', 'chair_019', 'chair_002', 'table_004', 'table_023', 'table_014', 'table_024', 'table_019', 'table_022', 'table_007', 'table_017', 'table_013', 'table_002', 'table_016', 'table_009', 'table_008', 'table_003', 'table_015', 'table_001', 'table_018', 'table_005', 'table_020', 'table_021', 'sofa_001', 'sofa_005', 'sofa_012', 'sofa_009', 'sofa_006', 'sofa_008', 'sofa_011', 'sofa_004', 'sofa_003', 'sofa_002', 'sofa_007', 'sofa_010', 'picnic_basket_005', 'picnic_basket_004', 'picnic_basket_001', 'picnic_basket_008', 'picnic_basket_002', 'picnic_basket_009', 'picnic_basket_006', 'picnic_basket_003', 'picnic_basket_007', 'cabinet_006', 'cabinet_008', 'cabinet_002', 'cabinet_001', 'cabinet_005', 'cabinet_007', 'flower_pot_013', 'flower_pot_005', 'flower_pot_008', 'flower_pot_001', 'flower_pot_003', 'flower_pot_020', 'flower_pot_006', 'flower_pot_012', 'flower_pot_018', 'flower_pot_007', 'flower_pot_002', 'flower_pot_011', 'flower_pot_010', 'flower_pot_016', 'flower_pot_004', 'flower_pot_014', 'flower_pot_017', 'flower_pot_019']
|
|
container_set = set(container_list)
|
|
return container_set
|
|
|
|
def get_ground_object_set(self, scene_name):
|
|
fall_path = os.path.join(self.data_path, scene_name, "fall_objects.pickle")
|
|
with open(fall_path, 'rb') as f:
|
|
fall_objects = pickle.load(f)
|
|
return fall_objects
|
|
|
|
def get_datalist(self):
|
|
if self.task == "object_pose":
|
|
raise NotImplementedError("object_pose task is not supported now.")
|
|
#return self.get_foundation_pose_datalist()
|
|
elif self.task == "grasp_pose":
|
|
return self.get_grasp_pose_datalist()
|
|
else:
|
|
raise ValueError("task must be 'object_pose' or 'grasp_pose'.")
|
|
|
|
'''
|
|
def get_foundation_pose_datalist(self):
|
|
data_list = []
|
|
for scene in self.scene_list:
|
|
scene_path = os.path.join(self.data_path, scene)
|
|
gsnet_label_scene_path = os.path.join(self.foundation_pose_label_path, scene)
|
|
file_list = os.listdir(scene_path)
|
|
scene_frame_list = []
|
|
target_object_set = self.get_target_object_set(scene)
|
|
unseen_object_set = self.get_target_object_set(scene)
|
|
cnt_under = 0
|
|
cnt_above = 0
|
|
limit = 0.002
|
|
for file in file_list:
|
|
if file.startswith("camera_params"):
|
|
frame_index = re.findall(r"\d+", file)[0]
|
|
frame_path = os.path.join(scene_path, frame_index)
|
|
|
|
score_label_path = os.path.join(
|
|
gsnet_label_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(frame_index)
|
|
)
|
|
with open(score_label_path, "r") as f:
|
|
score_label = json.load(f)
|
|
for obj_name in score_label.keys():
|
|
if obj_name in target_object_set:
|
|
scene_frame_list.append(
|
|
{
|
|
"frame_path": frame_path,
|
|
"object_name": obj_name,
|
|
"score": score_label[obj_name]["eval_result"]["ADD-S"],
|
|
"scene": scene
|
|
}
|
|
)
|
|
if score_label[obj_name]["eval_result"]["ADD-S"] <= limit:
|
|
cnt_under += 1
|
|
else:
|
|
cnt_above += 1
|
|
print(f"under {limit}: {cnt_under}, above {limit}: {cnt_above}")
|
|
for i in range(len(scene_frame_list)):
|
|
for j in range(i, len(scene_frame_list)):
|
|
fm_i, fm_j = scene_frame_list[i], scene_frame_list[j]
|
|
if fm_i["object_name"] == fm_j["object_name"]:
|
|
bad_view, good_view = None, None
|
|
if fm_i["score"] <= limit < fm_j["score"]:
|
|
good_view, bad_view = fm_i, fm_j
|
|
elif fm_i["score"] > limit >= fm_j["score"]:
|
|
good_view, bad_view = fm_j, fm_i
|
|
if bad_view is None or good_view is None:
|
|
continue
|
|
data_list.append((bad_view, good_view))
|
|
if bad_view["object_name"] in unseen_object_set:
|
|
unseen_object_set.remove(bad_view["object_name"])
|
|
return data_list
|
|
'''
|
|
|
|
    def get_grasp_pose_datalist(self):
        """Build (bad_view, good_view) pairs from per-frame GSNet grasp scores.

        For every scene: collect one record per (frame, graspable object),
        pair low-score views with high-score views of the same object (with
        probabilistic subsampling weighted by score gap), then for objects
        that never produced a pair, fall back to pairing the top ~10% of
        views against the rest.

        Returns:
            list[tuple[dict, dict]]: (bad_view, good_view) record pairs.
        """
        data_list = []
        for scene in self.scene_list:

            scene_path = os.path.join(self.data_path, scene)
            gsnet_label_scene_path = os.path.join(self.gsnet_label_path, scene)
            file_list = os.listdir(scene_path)
            scene_frame_list = []
            ground_object_set = self.get_ground_object_set(scene)
            # Objects that have not yet appeared as the "bad" half of a pair.
            unseen_object_set = set()
            for file in file_list:
                # One camera_params file marks each frame; its digits are the index.
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)

                    score_label_path = os.path.join(
                        gsnet_label_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(frame_index)
                    )
                    with open(score_label_path, "r") as f:
                        score_label = json.load(f)
                    # Skip fallen objects and container furniture.
                    for obj_name in score_label["avg_score"].keys():
                        if obj_name not in ground_object_set and obj_name not in self.container_set:
                            scene_frame_list.append(
                                {
                                    "frame_path": frame_path,
                                    "object_name": obj_name,
                                    "score": score_label["avg_score"][obj_name],
                                    "scene": scene
                                }
                            )
                            unseen_object_set.add(obj_name)
            # Pair every below-limit view with every above-limit view of the
            # same object (j starts at i; the i==j case never matches either
            # condition, so it is skipped without consuming randomness).
            for i in range(len(scene_frame_list)):
                for j in range(i, len(scene_frame_list)):
                    fm_i, fm_j = scene_frame_list[i], scene_frame_list[j]
                    if fm_i["object_name"] == fm_j["object_name"]:
                        bad_view, good_view = None, None
                        if fm_i["score"] <= self.score_limit < fm_j["score"]:
                            bad_view, good_view = fm_i, fm_j
                        elif fm_i["score"] > self.score_limit >= fm_j["score"]:
                            bad_view, good_view = fm_j, fm_i
                        if bad_view is None or good_view is None:
                            continue
                        # Keep a pair with probability proportional to the
                        # squared score gap (saturating at a gap of 0.3).
                        sample_prob = ((max(0,good_view["score"] - bad_view["score"]))/0.3)**2
                        if np.random.rand() > sample_prob:
                            continue
                        data_list.append((bad_view, good_view))
                        if bad_view["object_name"] in unseen_object_set:
                            unseen_object_set.remove(bad_view["object_name"])
            # Fallback for objects with no pairs yet: rank views by score,
            # take the top 10% (if scoring >= 0.01) as good, rest as bad.
            for obj_name in unseen_object_set:
                views = []
                for frame in scene_frame_list:
                    if frame["object_name"] == obj_name:
                        views.append(frame)
                sorted_views = sorted(views, key=lambda x: x["score"], reverse=True)
                total_view_num = len(sorted_views)
                good_view_num = int(total_view_num * 0.1)
                good_views = sorted_views[:good_view_num]
                bad_views = sorted_views[good_view_num:]
                filtered_good_view = []
                # Deliberate alias: demoted "good" views are appended to bad_views too.
                filtered_bad_view = bad_views
                for good_view in good_views:
                    if good_view["score"] >= 0.01:
                        filtered_good_view.append(good_view)
                    else:
                        filtered_bad_view.append(good_view)
                for good_view in filtered_good_view:
                    for bad_view in filtered_bad_view:
                        data_list.append((bad_view, good_view))
        return data_list
|