import torch
import asyncio
import numpy as np
from concurrent.futures import ThreadPoolExecutor

from utils.pose_util import PoseUtil
from runners.preprocessors.grasping.GSNet_preprocessor import GSNetPreprocessor
from torch.utils.data import Dataset
from annotations.singleton import singleton
from baselines.grasping.GSNet.models.graspnet import GraspNet
from baselines.grasping.GSNet.dataset.graspnet_dataset import minkowski_collate_fn
from configs.config import ConfigManager
from utils.view_util import ViewUtil
import annotations.stereotype as stereotype
from tqdm import tqdm

class GSNetInferenceDataset(Dataset):
    def __init__(
        self,
        view_data_list,
        scene_pts_num=15000,
        voxel_size=0.005,
    ):
        self.scene_pts_num = scene_pts_num
        self.voxel_size = voxel_size
        self.view_data_list = view_data_list

    def __len__(self):
        return len(self.view_data_list)

    def __getitem__(self, index):
        view_data = self.view_data_list[index]
        object_name, scene_pts, obj_pcl_dict = view_data
        ret_dict = {
            "frame_path": index,
            "point_clouds": scene_pts.astype(np.float32),
            "coors": scene_pts.astype(np.float32) / self.voxel_size,
            "feats": np.ones_like(scene_pts).astype(np.float32),
            "obj_pcl_dict": obj_pcl_dict,
            "object_name": object_name,
        }
        return ret_dict

    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        # Sample with replacement only when the cloud has fewer points than requested.
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]
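
# Illustrative sketch (hypothetical toy data): each view_data_list entry is an
# (object_name, scene_pts, obj_pcl_dict) triple, as produced by
# extract_view_pts_from_view() below, and the dataset repackages it into the dict
# layout expected by minkowski_collate_fn.
#
#   toy_view = ("mug", np.random.rand(15000, 3), {"mug": np.random.rand(1024, 3)})
#   toy_dataset = GSNetInferenceDataset([toy_view])
#   toy_dataset[0]["point_clouds"].shape  # (15000, 3)
#   toy_dataset[0]["coors"].shape         # (15000, 3), coordinates in voxel units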

@singleton
class GSNetEvaluator(GSNetPreprocessor):
    def __init__(self):
        self.model = self.get_model(
            model_path=ConfigManager.get("settings", "experiment", "grasp_model_path")
        )

    def get_dataloader(self, view_data_list):
        def my_worker_init_fn(worker_id):
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset(view_data_list)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        return dataloader

    def get_model(self, seed_feat_dim=512, model_path="default"):
        model = GraspNet(seed_feat_dim=seed_feat_dim, is_training=False)
        model.to("cuda")
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("-> loaded checkpoint %s (epoch: %d)" % (model_path, start_epoch))
        model.eval()
        return model

def get_transformed_mat(src_mat, delta_rot, target_center_w):
    # Apply the delta rotation to the source pose, then move the camera so that it
    # keeps its original distance to the target center while its new z-axis points
    # at that center.
    src_rot = src_mat[:3, :3]
    dst_rot = src_rot @ delta_rot.T
    dst_mat = torch.eye(4).to(dst_rot.device)
    dst_mat[:3, :3] = dst_rot
    distance = torch.norm(target_center_w - src_mat[:3, 3])
    z_axis_camera = dst_rot[:3, 2].reshape(-1)
    new_camera_position_w = target_center_w - distance * z_axis_camera
    dst_mat[:3, 3] = new_camera_position_w
    return dst_mat
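
# Illustrative sanity check (hypothetical helper, not part of the evaluation pipeline):
# the transformed pose should keep the source camera's distance to the target center,
# with its z-axis pointing from the new position toward that center.
def _example_check_transformed_mat():
    src_mat = torch.eye(4)
    src_mat[:3, 3] = torch.tensor([0.0, 0.0, 1.0])  # camera 1 m away from the origin
    delta_rot = torch.tensor([  # 90-degree rotation about the x-axis
        [1.0, 0.0, 0.0],
        [0.0, 0.0, -1.0],
        [0.0, 1.0, 0.0],
    ])
    target_center_w = torch.zeros(3)
    dst_mat = get_transformed_mat(src_mat, delta_rot, target_center_w)
    # distance to the target center is preserved ...
    assert torch.allclose(torch.norm(dst_mat[:3, 3] - target_center_w),
                          torch.norm(src_mat[:3, 3] - target_center_w))
    # ... and the new z-axis points back at the target center (distance is 1 m here)
    assert torch.allclose(dst_mat[:3, 3] + 1.0 * dst_mat[:3, 2], target_center_w, atol=1e-6)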

def get_score_from_processed_data(processed_data, object_name_list):
    # Average the predicted grasp score of each frame's target object; frames whose
    # target received no grasp prediction contribute a score of 0.
    score = 0
    cnt = 0
    for key in processed_data:
        object_name = object_name_list[key]
        if object_name not in processed_data[key]["avg_score"]:
            avg_score = 0
        else:
            avg_score = processed_data[key]["avg_score"][object_name]
        score += avg_score
        cnt += 1
    return score / cnt
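
# Illustrative sketch (hypothetical toy data): processed_data is keyed by frame index
# and holds per-object "avg_score" dicts, so two frames scoring 0.6 and 0 (no grasp
# found for the target) average to 0.3.
#
#   toy_processed = {0: {"avg_score": {"mug": 0.6}}, 1: {"avg_score": {}}}
#   get_score_from_processed_data(toy_processed, ["mug", "bottle"])  # -> 0.3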

def sample_points(points, target_num_points):
    num_points = points.shape[0]
    if num_points == 0:
        return np.zeros((target_num_points, points.shape[1]))
    if num_points > target_num_points:
        # Downsample without replacement.
        indices = np.random.choice(num_points, target_num_points, replace=False)
    else:
        # Pad by resampling with replacement.
        indices = np.random.choice(num_points, target_num_points, replace=True)
    return points[indices]
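
# Illustrative usage (hypothetical shapes): the output always has exactly
# target_num_points rows for a non-empty input, repeating points when the cloud
# is too small.
#
#   sample_points(np.random.rand(5000, 3), 1024).shape  # (1024, 3)
#   sample_points(np.random.rand(200, 3), 1024).shape   # (1024, 3), with repeats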

def sample_dict_to_target_points(croped_pts_dict, total_points=15000):
    all_sampled_points = []
    sampled_pts_dict = {}
    total_existing_points = sum([pts.shape[0] for pts in croped_pts_dict.values() if pts.shape[0] > 0])

    if total_existing_points == 0:
        # Nothing visible in the crop: return an all-zero scene cloud.
        for name, pts in croped_pts_dict.items():
            sampled_pts_dict[name] = pts
        return np.zeros((total_points, 3)), sampled_pts_dict

    if total_existing_points > total_points:
        # Downsample each non-empty cloud in proportion to its size, then hand out the
        # points lost to integer truncation one by one.
        ratios = {name: len(pts) / total_existing_points for name, pts in croped_pts_dict.items() if pts.shape[0] > 0}
        target_num_points = {name: int(ratio * total_points) for name, ratio in ratios.items()}
        remaining_points = total_points - sum(target_num_points.values())
        for name in target_num_points.keys():
            if remaining_points > 0:
                target_num_points[name] += 1
                remaining_points -= 1
    else:
        # Keep every existing point and pad up to the budget by drawing extra points
        # from randomly chosen non-empty clouds.
        target_num_points = {name: len(pts) for name, pts in croped_pts_dict.items()}
        remaining_points = total_points - total_existing_points
        additional_points = np.random.choice(
            [name for name, pts in croped_pts_dict.items() if pts.shape[0] > 0],
            remaining_points,
            replace=True,
        )
        for name in additional_points:
            target_num_points[name] += 1

    for name, pts in croped_pts_dict.items():
        if pts.shape[0] == 0:
            sampled_pts_dict[name] = pts
            continue
        sampled_pts = sample_points(pts, target_num_points[name])
        sampled_pts_dict[name] = sampled_pts
        all_sampled_points.append(sampled_pts)

    if len(all_sampled_points) > 0:
        sampled_scene_pts = np.concatenate(all_sampled_points, axis=0)
    else:
        sampled_scene_pts = np.zeros((total_points, 3))
    return sampled_scene_pts, sampled_pts_dict
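
# Illustrative sketch (hypothetical toy input): when more points exist than the budget
# allows, each non-empty cloud is downsampled in proportion to its size, while empty
# clouds pass through unchanged.
#
#   toy_dict = {
#       "mug": np.random.rand(600, 3),
#       "table": np.random.rand(1800, 3),
#       "occluded_obj": np.zeros((0, 3)),
#   }
#   scene_pts, per_obj = sample_dict_to_target_points(toy_dict, total_points=1200)
#   scene_pts.shape        # (1200, 3)
#   per_obj["mug"].shape   # (300, 3), i.e. 600 / 2400 of the budget
#   per_obj["table"].shape # (900, 3)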

def extract_view_pts_from_view(obj_name, rgb, depth, seg, seg_labels, camera_params):
    pts_dict = ViewUtil.get_pts_dict(depth, seg, seg_labels, camera_params)
    obj_center = ViewUtil.get_object_center_from_pts_dict(obj_name, pts_dict)
    croped_pts_dict = ViewUtil.crop_pts_dict(pts_dict, obj_center, radius=0.2)

    sampled_scene_pts, sampled_pts_dict = sample_dict_to_target_points(croped_pts_dict)

    return obj_name, sampled_scene_pts, sampled_pts_dict

async def async_get_view(total, all_src_mat_list, all_part_gt_dst_mat_list, all_full_gt_dst_mat_list, all_est_dst_mat_list,
                         all_source_list, all_data_type_list, all_scene_name_list, all_object_name_list, web_server_port):

    all_src_view_data_list = []
    all_part_gt_dst_view_data_list = []
    all_full_gt_dst_view_data_list = []
    all_est_dst_view_data_list = []

    with ThreadPoolExecutor() as executor:
        loop = asyncio.get_event_loop()
        for i in tqdm(range(total), desc="----Processing items", ncols=100):
            src_mat = all_src_mat_list[i]
            part_gt_dst_mat = all_part_gt_dst_mat_list[i]
            full_gt_dst_mat = all_full_gt_dst_mat_list[i]
            est_dst_mat = all_est_dst_mat_list[i]
            source = all_source_list[i]
            data_type = all_data_type_list[i]
            scene_name = all_scene_name_list[i]
            obj_name = all_object_name_list[i]

            # Render the four views concurrently, one web-server port per view type.
            src_view_future = loop.run_in_executor(executor, ViewUtil.get_view, src_mat, source, data_type, scene_name, web_server_port)
            part_gt_dst_view_future = loop.run_in_executor(executor, ViewUtil.get_view, part_gt_dst_mat, source, data_type, scene_name, web_server_port + 1)
            full_gt_dst_view_future = loop.run_in_executor(executor, ViewUtil.get_view, full_gt_dst_mat, source, data_type, scene_name, web_server_port + 2)
            est_dst_view_future = loop.run_in_executor(executor, ViewUtil.get_view, est_dst_mat, source, data_type, scene_name, web_server_port + 3)

            src_view_data, part_gt_dst_view_data, full_gt_dst_view_data, est_dst_view_data = await asyncio.gather(
                src_view_future, part_gt_dst_view_future, full_gt_dst_view_future, est_dst_view_future
            )

            all_src_view_data_list.append(extract_view_pts_from_view(obj_name, *src_view_data))
            all_part_gt_dst_view_data_list.append(extract_view_pts_from_view(obj_name, *part_gt_dst_view_data))
            all_full_gt_dst_view_data_list.append(extract_view_pts_from_view(obj_name, *full_gt_dst_view_data))
            all_est_dst_view_data_list.append(extract_view_pts_from_view(obj_name, *est_dst_view_data))

    return (all_src_view_data_list, all_part_gt_dst_view_data_list, all_full_gt_dst_view_data_list, all_est_dst_view_data_list)

@stereotype.evaluation_method("grasp_pose_improvement")
def evaluate(output_list, data_list):
    evaluator = GSNetEvaluator()
    web_server_port = ConfigManager.get("settings", "experiment", "web_api", "port")
    all_src_mat_list = []
    all_part_gt_dst_mat_list = []
    all_full_gt_dst_mat_list = []
    all_est_dst_mat_list = []
    all_scene_name_list = []
    all_object_name_list = []
    all_source_list = []
    all_data_type_list = []
    all_target_center_w_list = []
    for output, data in zip(output_list, data_list):
        gt_delta_rot_6d_list = data["delta_rot_6d"]
        est_delta_rot_6d_list = output["estimated_delta_rot_6d"]
        src_mat_list = data["src_transform"]
        gt_mat_list = data["dst_transform"]
        scene_name_list = data["scene_name"]
        object_name_list = data["target_name"]
        target_pts_list = data["target_pts"]
        source_list = data["source"]
        data_type_list = data["data_type"]
        # Target centers in the camera frame, then transformed into the world frame.
        target_center_c_list = torch.mean(target_pts_list, axis=1)
        target_center_w_list = torch.bmm(src_mat_list[:, :3, :3], target_center_c_list.unsqueeze(2)).squeeze(2) + src_mat_list[:, :3, 3]
        gt_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_delta_rot_6d_list)
        est_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(est_delta_rot_6d_list)
        for i in range(len(scene_name_list)):
            src_mat = src_mat_list[i]
            target_center_w = target_center_w_list[i]
            gt_delta_rot_mat = gt_delta_rot_mat_list[i]
            est_delta_rot_mat = est_delta_rot_mat_list[i]
            part_gt_dst_mat = get_transformed_mat(src_mat, gt_delta_rot_mat, target_center_w)
            est_dst_mat = get_transformed_mat(src_mat, est_delta_rot_mat, target_center_w)
            all_src_mat_list.append(src_mat)
            all_part_gt_dst_mat_list.append(part_gt_dst_mat)
            all_full_gt_dst_mat_list.append(gt_mat_list[i])
            all_est_dst_mat_list.append(est_dst_mat)
            all_scene_name_list.append(scene_name_list[i])
            all_object_name_list.append(object_name_list[i])
            all_source_list.append(source_list[i])
            all_data_type_list.append(data_type_list[i])
            all_target_center_w_list.append(target_center_w)
    all_src_view_data_list = []
    all_part_gt_dst_view_data_list = []
    all_full_gt_dst_view_data_list = []
    all_est_dst_view_data_list = []
    total = len(all_src_mat_list)

    loop = asyncio.get_event_loop()
    all_view_data_list = loop.run_until_complete(async_get_view(total, all_src_mat_list, all_part_gt_dst_mat_list, all_full_gt_dst_mat_list, all_est_dst_mat_list,
                                                                all_source_list, all_data_type_list, all_scene_name_list, all_object_name_list, web_server_port))
    all_src_view_data_list, all_part_gt_dst_view_data_list, all_full_gt_dst_view_data_list, all_est_dst_view_data_list = all_view_data_list

    src_dataloader = evaluator.get_dataloader(all_src_view_data_list)
    part_gt_dst_dataloader = evaluator.get_dataloader(all_part_gt_dst_view_data_list)
    full_gt_dst_dataloader = evaluator.get_dataloader(all_full_gt_dst_view_data_list)
    est_dst_dataloader = evaluator.get_dataloader(all_est_dst_view_data_list)

    src_predicted_data = evaluator.prediction(evaluator.model, src_dataloader, require_gripper=False)
    part_gt_dst_predicted_data = evaluator.prediction(evaluator.model, part_gt_dst_dataloader, require_gripper=False)
    full_gt_dst_predicted_data = evaluator.prediction(evaluator.model, full_gt_dst_dataloader, require_gripper=False)
    est_dst_predicted_data = evaluator.prediction(evaluator.model, est_dst_dataloader, require_gripper=False)

    src_processed_data = evaluator.preprocess(src_predicted_data, require_gripper=False)
    part_gt_dst_processed_data = evaluator.preprocess(part_gt_dst_predicted_data, require_gripper=False)
    full_gt_dst_processed_data = evaluator.preprocess(full_gt_dst_predicted_data, require_gripper=False)
    est_dst_processed_data = evaluator.preprocess(est_dst_predicted_data, require_gripper=False)

    src_score = get_score_from_processed_data(src_processed_data, all_object_name_list)
    part_gt_dst_score = get_score_from_processed_data(part_gt_dst_processed_data, all_object_name_list)
    full_gt_dst_score = get_score_from_processed_data(full_gt_dst_processed_data, all_object_name_list)
    est_dst_score = get_score_from_processed_data(est_dst_processed_data, all_object_name_list)

    score_improvement = est_dst_score - src_score
    score_diff_to_full_gt = full_gt_dst_score - est_dst_score
    score_diff_to_part_gt = part_gt_dst_score - est_dst_score
    look_at_center_score_diff = full_gt_dst_score - part_gt_dst_score

    results = {
        "scalars": {
            "grasp_pose_score": {
                "src": src_score,
                "part_gt_dst": part_gt_dst_score,
                "full_gt_dst": full_gt_dst_score,
                "est_dst": est_dst_score,
            },
            "grasp_pose_score_improvement": score_improvement,
            "grasp_pose_score_diff_to_full_gt": score_diff_to_full_gt,
            "grasp_pose_score_diff_to_part_gt": score_diff_to_part_gt,
            "grasp_pose_look_at_center_score_diff": look_at_center_score_diff,
        }
    }
    return results

if __name__ == "__main__":
    pass