import asyncio
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

import annotations.stereotype as stereotype
from annotations.singleton import singleton
from baselines.grasping.GSNet.dataset.graspnet_dataset import minkowski_collate_fn
from baselines.grasping.GSNet.models.graspnet import GraspNet
from configs.config import ConfigManager
from runners.preprocessors.grasping.GSNet_preprocessor import GSNetPreprocessor
from utils.pose_util import PoseUtil
from utils.view_util import ViewUtil


class GSNetInferenceDataset(Dataset):
    def __init__(
        self,
        view_data_list,
        scene_pts_num=15000,
        voxel_size=0.005,
    ):
        self.scene_pts_num = scene_pts_num
        self.voxel_size = voxel_size
        self.view_data_list = view_data_list

    def __len__(self):
        return len(self.view_data_list)

    def __getitem__(self, index):
        view_data = self.view_data_list[index]
        object_name, scene_pts, obj_pcl_dict = view_data
        ret_dict = {
            "frame_path": index,
            "point_clouds": scene_pts.astype(np.float32),
            "coors": scene_pts.astype(np.float32) / self.voxel_size,
            "feats": np.ones_like(scene_pts).astype(np.float32),
            "obj_pcl_dict": obj_pcl_dict,
            "object_name": object_name,
        }
        return ret_dict

    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        # Sample with replacement only when the cloud is smaller than the request.
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]


@singleton
class GSNetEvaluator(GSNetPreprocessor):
    def __init__(self):
        self.model = self.get_model(
            model_path=ConfigManager.get("settings", "experiment", "grasp_model_path")
        )

    def get_dataloader(self, view_data_list):
        def my_worker_init_fn(worker_id):
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset(view_data_list)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        return dataloader

    def get_model(self, seed_feat_dim=512, model_path="default"):
        model = GraspNet(seed_feat_dim=seed_feat_dim, is_training=False)
        model.to("cuda")
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("-> loaded checkpoint %s (epoch: %d)" % (model_path, start_epoch))
        model.eval()
        return model


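# The pose construction below applies the (predicted or ground-truth) delta rotation to
# the source camera orientation, then re-positions the camera so that it keeps its
# original distance to the target center while looking at the center along the new
# z-axis (assuming a camera convention whose optical axis is +z). A quick sanity check
# of the distance invariant (illustrative only, identity delta rotation):
#
#     src = torch.eye(4)
#     src[:3, 3] = torch.tensor([0.0, 0.0, 1.0])
#     dst = get_transformed_mat(src, torch.eye(3), torch.zeros(3))
#     torch.norm(dst[:3, 3])  # equals torch.norm(src[:3, 3]) == 1.0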
def get_transformed_mat(src_mat, delta_rot, target_center_w):
    src_rot = src_mat[:3, :3]
    dst_rot = src_rot @ delta_rot.T
    dst_mat = torch.eye(4).to(dst_rot.device)
    dst_mat[:3, :3] = dst_rot
    # Keep the original camera-to-target distance along the new optical (z) axis.
    distance = torch.norm(target_center_w - src_mat[:3, 3])
    z_axis_camera = dst_rot[:3, 2].reshape(-1)
    new_camera_position_w = target_center_w - distance * z_axis_camera
    dst_mat[:3, 3] = new_camera_position_w
    return dst_mat


def get_score_from_processed_data(processed_data, object_name_list):
    score = 0
    cnt = 0
    for key in processed_data:
        object_name = object_name_list[key]
        if object_name not in processed_data[key]["avg_score"]:
            avg_score = 0
        else:
            avg_score = processed_data[key]["avg_score"][object_name]
        score += avg_score
        cnt += 1
    if cnt == 0:
        return 0
    return score / cnt


def sample_points(points, target_num_points):
    num_points = points.shape[0]
    if num_points == 0:
        return np.zeros((target_num_points, points.shape[1]))
    if num_points > target_num_points:
        indices = np.random.choice(num_points, target_num_points, replace=False)
    else:
        indices = np.random.choice(num_points, target_num_points, replace=True)
    return points[indices]


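# sample_dict_to_target_points resamples the per-object clouds so that their
# concatenation has exactly `total_points` points: proportional downsampling when the
# crop holds too many points, upsampling with replacement when it holds too few.
# Illustrative usage (hypothetical shapes):
#
#     pts = {"a": np.random.rand(2000, 3), "b": np.random.rand(500, 3)}
#     scene, per_obj = sample_dict_to_target_points(pts, total_points=1500)
#     scene.shape  # (1500, 3)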
def sample_dict_to_target_points(cropped_pts_dict, total_points=15000):
    all_sampled_points = []
    sampled_pts_dict = {}
    total_existing_points = sum(
        pts.shape[0] for pts in cropped_pts_dict.values() if pts.shape[0] > 0
    )

    if total_existing_points == 0:
        for name, pts in cropped_pts_dict.items():
            sampled_pts_dict[name] = pts
        return np.zeros((total_points, 3)), sampled_pts_dict

    if total_existing_points > total_points:
        # Downsample each non-empty cloud proportionally to its size, then hand out
        # the points lost to integer truncation one per cloud.
        ratios = {
            name: len(pts) / total_existing_points
            for name, pts in cropped_pts_dict.items()
            if pts.shape[0] > 0
        }
        target_num_points = {
            name: int(ratio * total_points) for name, ratio in ratios.items()
        }
        remaining_points = total_points - sum(target_num_points.values())
        for name in target_num_points.keys():
            if remaining_points > 0:
                target_num_points[name] += 1
                remaining_points -= 1
    else:
        # Keep every existing point and upsample randomly chosen non-empty clouds
        # to fill the remainder.
        target_num_points = {name: len(pts) for name, pts in cropped_pts_dict.items()}
        remaining_points = total_points - total_existing_points
        additional_points = np.random.choice(
            [name for name, pts in cropped_pts_dict.items() if pts.shape[0] > 0],
            remaining_points,
            replace=True,
        )
        for name in additional_points:
            target_num_points[name] += 1

    for name, pts in cropped_pts_dict.items():
        if pts.shape[0] == 0:
            sampled_pts_dict[name] = pts
            continue
        sampled_pts = sample_points(pts, target_num_points[name])
        sampled_pts_dict[name] = sampled_pts
        all_sampled_points.append(sampled_pts)

    if len(all_sampled_points) > 0:
        sampled_scene_pts = np.concatenate(all_sampled_points, axis=0)
    else:
        sampled_scene_pts = np.zeros((total_points, 3))
    return sampled_scene_pts, sampled_pts_dict


def extract_view_pts_from_view(obj_name, rgb, depth, seg, seg_labels, camera_params):
    pts_dict = ViewUtil.get_pts_dict(depth, seg, seg_labels, camera_params)
    obj_center = ViewUtil.get_object_center_from_pts_dict(obj_name, pts_dict)
    cropped_pts_dict = ViewUtil.crop_pts_dict(pts_dict, obj_center, radius=0.2)
    sampled_scene_pts, sampled_pts_dict = sample_dict_to_target_points(cropped_pts_dict)
    return obj_name, sampled_scene_pts, sampled_pts_dict


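# The four candidate poses of each sample (source, partial ground truth, full ground
# truth, estimate) are rendered concurrently: every pose is dispatched to its own web
# server instance on consecutive ports (web_server_port through web_server_port + 3)
# via a thread pool, and the renders are awaited together with asyncio.gather.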
async def async_get_view(
    total,
    all_src_mat_list,
    all_part_gt_dst_mat_list,
    all_full_gt_dst_mat_list,
    all_est_dst_mat_list,
    all_source_list,
    all_data_type_list,
    all_scene_name_list,
    all_object_name_list,
    web_server_port,
):
    all_src_view_data_list = []
    all_part_gt_dst_view_data_list = []
    all_full_gt_dst_view_data_list = []
    all_est_dst_view_data_list = []

    with ThreadPoolExecutor() as executor:
        loop = asyncio.get_running_loop()
        for i in tqdm(range(total), desc="----Processing items", ncols=100):
            src_mat = all_src_mat_list[i]
            part_gt_dst_mat = all_part_gt_dst_mat_list[i]
            full_gt_dst_mat = all_full_gt_dst_mat_list[i]
            est_dst_mat = all_est_dst_mat_list[i]
            source = all_source_list[i]
            data_type = all_data_type_list[i]
            scene_name = all_scene_name_list[i]
            obj_name = all_object_name_list[i]
            src_view_future = loop.run_in_executor(
                executor, ViewUtil.get_view, src_mat, source, data_type, scene_name, web_server_port
            )
            part_gt_dst_view_future = loop.run_in_executor(
                executor, ViewUtil.get_view, part_gt_dst_mat, source, data_type, scene_name, web_server_port + 1
            )
            full_gt_dst_view_future = loop.run_in_executor(
                executor, ViewUtil.get_view, full_gt_dst_mat, source, data_type, scene_name, web_server_port + 2
            )
            est_dst_view_future = loop.run_in_executor(
                executor, ViewUtil.get_view, est_dst_mat, source, data_type, scene_name, web_server_port + 3
            )
            src_view_data, part_gt_dst_view_data, full_gt_dst_view_data, est_dst_view_data = await asyncio.gather(
                src_view_future,
                part_gt_dst_view_future,
                full_gt_dst_view_future,
                est_dst_view_future,
            )
            all_src_view_data_list.append(extract_view_pts_from_view(obj_name, *src_view_data))
            all_part_gt_dst_view_data_list.append(extract_view_pts_from_view(obj_name, *part_gt_dst_view_data))
            all_full_gt_dst_view_data_list.append(extract_view_pts_from_view(obj_name, *full_gt_dst_view_data))
            all_est_dst_view_data_list.append(extract_view_pts_from_view(obj_name, *est_dst_view_data))
    return (
        all_src_view_data_list,
        all_part_gt_dst_view_data_list,
        all_full_gt_dst_view_data_list,
        all_est_dst_view_data_list,
    )


# Evaluation entry point: renders four camera poses per sample (the source view "src",
# the look-at-center pose built from the ground-truth delta rotation "part_gt_dst",
# the full ground-truth destination pose "full_gt_dst", and the estimated pose
# "est_dst"), scores the grasps GSNet predicts at each, and reports the score gaps.
@stereotype.evaluation_method("grasp_pose_improvement")
def evaluate(output_list, data_list):
    evaluator = GSNetEvaluator()
    web_server_port = ConfigManager.get("settings", "experiment", "web_api", "port")
    all_src_mat_list = []
    all_part_gt_dst_mat_list = []
    all_full_gt_dst_mat_list = []
    all_est_dst_mat_list = []
    all_scene_name_list = []
    all_object_name_list = []
    all_source_list = []
    all_data_type_list = []
    all_target_center_w_list = []

    for output, data in zip(output_list, data_list):
        gt_delta_rot_6d_list = data["delta_rot_6d"]
        est_delta_rot_6d_list = output["estimated_delta_rot_6d"]
        src_mat_list = data["src_transform"]
        gt_mat_list = data["dst_transform"]
        scene_name_list = data["scene_name"]
        object_name_list = data["target_name"]
        target_pts_list = data["target_pts"]
        source_list = data["source"]
        data_type_list = data["data_type"]
        # Lift the per-sample target centers from camera to world coordinates.
        target_center_c_list = torch.mean(target_pts_list, dim=1)
        target_center_w_list = (
            torch.bmm(src_mat_list[:, :3, :3], target_center_c_list.unsqueeze(2)).squeeze(2)
            + src_mat_list[:, :3, 3]
        )
        gt_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_delta_rot_6d_list)
        est_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(est_delta_rot_6d_list)
        for i in range(len(scene_name_list)):
            src_mat = src_mat_list[i]
            target_center_w = target_center_w_list[i]
            gt_delta_rot_mat = gt_delta_rot_mat_list[i]
            est_delta_rot_mat = est_delta_rot_mat_list[i]
            part_gt_dst_mat = get_transformed_mat(src_mat, gt_delta_rot_mat, target_center_w)
            est_dst_mat = get_transformed_mat(src_mat, est_delta_rot_mat, target_center_w)
            all_src_mat_list.append(src_mat)
            all_part_gt_dst_mat_list.append(part_gt_dst_mat)
            all_full_gt_dst_mat_list.append(gt_mat_list[i])
            all_est_dst_mat_list.append(est_dst_mat)
            all_scene_name_list.append(scene_name_list[i])
            all_object_name_list.append(object_name_list[i])
            all_source_list.append(source_list[i])
            all_data_type_list.append(data_type_list[i])
            all_target_center_w_list.append(target_center_w)

    total = len(all_src_mat_list)
    (
        all_src_view_data_list,
        all_part_gt_dst_view_data_list,
        all_full_gt_dst_view_data_list,
        all_est_dst_view_data_list,
    ) = asyncio.run(
        async_get_view(
            total,
            all_src_mat_list,
            all_part_gt_dst_mat_list,
            all_full_gt_dst_mat_list,
            all_est_dst_mat_list,
            all_source_list,
            all_data_type_list,
            all_scene_name_list,
            all_object_name_list,
            web_server_port,
        )
    )

    src_dataloader = evaluator.get_dataloader(all_src_view_data_list)
    part_gt_dst_dataloader = evaluator.get_dataloader(all_part_gt_dst_view_data_list)
    full_gt_dst_dataloader = evaluator.get_dataloader(all_full_gt_dst_view_data_list)
    est_dst_dataloader = evaluator.get_dataloader(all_est_dst_view_data_list)

    src_predicted_data = evaluator.prediction(evaluator.model, src_dataloader, require_gripper=False)
    part_gt_dst_predicted_data = evaluator.prediction(evaluator.model, part_gt_dst_dataloader, require_gripper=False)
    full_gt_dst_predicted_data = evaluator.prediction(evaluator.model, full_gt_dst_dataloader, require_gripper=False)
    est_dst_predicted_data = evaluator.prediction(evaluator.model, est_dst_dataloader, require_gripper=False)

    src_processed_data = evaluator.preprocess(src_predicted_data, require_gripper=False)
    part_gt_dst_processed_data = evaluator.preprocess(part_gt_dst_predicted_data, require_gripper=False)
    full_gt_dst_processed_data = evaluator.preprocess(full_gt_dst_predicted_data, require_gripper=False)
    est_dst_processed_data = evaluator.preprocess(est_dst_predicted_data, require_gripper=False)

    src_score = get_score_from_processed_data(src_processed_data, all_object_name_list)
    part_gt_dst_score = get_score_from_processed_data(part_gt_dst_processed_data, all_object_name_list)
    full_gt_dst_score = get_score_from_processed_data(full_gt_dst_processed_data, all_object_name_list)
    est_dst_score = get_score_from_processed_data(est_dst_processed_data, all_object_name_list)

    score_improvement = est_dst_score - src_score
    score_diff_to_full_gt = full_gt_dst_score - est_dst_score
    score_diff_to_part_gt = part_gt_dst_score - est_dst_score
    look_at_center_score_diff = full_gt_dst_score - part_gt_dst_score

    results = {
        "scalars": {
            "grasp_pose_score": {
                "src": src_score,
                "part_gt_dst": part_gt_dst_score,
                "full_gt_dst": full_gt_dst_score,
                "est_dst": est_dst_score,
            },
            "grasp_pose_score_improvement": score_improvement,
            "grasp_pose_score_diff_to_full_gt": score_diff_to_full_gt,
            "grasp_pose_score_diff_to_part_gt": score_diff_to_part_gt,
            "grasp_pose_look_at_center_score_diff": look_at_center_score_diff,
        }
    }
    return results


if __name__ == "__main__":
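    # Minimal smoke test (illustrative only, not part of the evaluation pipeline):
    # sample_dict_to_target_points should return exactly `total_points` scene points
    # even when some crops are empty.
    demo_pts_dict = {
        "obj_a": np.random.rand(2000, 3),
        "obj_b": np.random.rand(500, 3),
        "empty": np.zeros((0, 3)),
    }
    demo_scene_pts, _ = sample_dict_to_target_points(demo_pts_dict, total_points=1500)
    print("sampled scene shape:", demo_scene_pts.shape)  # expected: (1500, 3)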