import os
import re
import sys

import numpy as np
import torch
import open3d as o3d
from torch.utils.data import DataLoader, Dataset

# Walk four levels up from this file to locate the project root.
path = os.path.abspath(__file__)
for i in range(4):
    path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)

GSNET_PROJECT_ROOT = os.path.join(PROJECT_ROOT, "baselines/grasping/GSNet")
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "pointnet2"))
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "utils"))
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "models"))
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "dataset"))

from utils.omni_util import OmniUtil
from utils.view_util import ViewUtil
from runners.preprocessors.grasping.abstract_grasping_preprocessor import GraspingPreprocessor
from configs.config import ConfigManager
from baselines.grasping.GSNet.models.graspnet import GraspNet
from baselines.grasping.GSNet.graspnetAPI.graspnetAPI.graspnet_eval import GraspGroup
from baselines.grasping.GSNet.dataset.graspnet_dataset import minkowski_collate_fn


class GSNetInferenceDataset(Dataset):
    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        scene_pts_num=15000,
        voxel_size=0.005,
    ):
        self.data_dir = data_dir
        self.scene_pts_num = scene_pts_num
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.voxel_size = voxel_size

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path, target = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path, object_name=target)
        return frame_data

    def get_datalist(self):
        # Build (frame_path, target_object) pairs; frames without objects get
        # a (frame_path, None) placeholder so they are still visited.
        scene_frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            file_list = os.listdir(scene_path)
            for file in file_list:
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    target_list = OmniUtil.get_object_list(frame_path)
                    for target in target_list:
                        scene_frame_list.append((frame_path, target))
                    if len(target_list) == 0:
                        scene_frame_list.append((frame_path, None))
            print("Scene: ", scene, ", cumulative frame-target pairs: ", len(scene_frame_list))
        return scene_frame_list

    def load_frame_data(self, frame_path, object_name):
        try:
            target_list = OmniUtil.get_object_list(path=frame_path, contains_non_obj=True)
            _, obj_pcl_dict = OmniUtil.get_segmented_points(
                path=frame_path, target_list=target_list
            )
            obj_center = ViewUtil.get_object_center_from_pts_dict(object_name, obj_pcl_dict)
            cropped_pts_dict = ViewUtil.crop_pts_dict(obj_pcl_dict, obj_center, radius=0.2)
            sampled_scene_pts, sampled_pts_dict = GSNetInferenceDataset.sample_dict_to_target_points(cropped_pts_dict)
            ret_dict = {
                "frame_path": frame_path,
                "point_clouds": sampled_scene_pts.astype(np.float32),
                "coors": sampled_scene_pts.astype(np.float32) / self.voxel_size,
                "feats": np.ones_like(sampled_scene_pts).astype(np.float32),
                "obj_pcl_dict": sampled_pts_dict,
                "object_name": object_name,
            }
        except Exception as e:
            # Fall back to an all-zero frame and flag it so inference skips it.
            print("Error in loading frame data: ", e)
            ret_dict = {
                "frame_path": frame_path,
                "point_clouds": np.zeros((self.scene_pts_num, 3)).astype(np.float32),
                "coors": np.zeros((self.scene_pts_num, 3)).astype(np.float32),
                "feats": np.ones((self.scene_pts_num, 3)).astype(np.float32),
                "obj_pcl_dict": {},
                "object_name": object_name,
                "error": True,
            }
        return ret_dict

    @staticmethod
    def sample_points(points, target_num_points):
        num_points = points.shape[0]
        if num_points == 0:
            return np.zeros((target_num_points, points.shape[1]))
        if num_points > target_num_points:
            indices = np.random.choice(num_points, target_num_points, replace=False)
        else:
            indices = np.random.choice(num_points, target_num_points, replace=True)
        return points[indices]

    @staticmethod
    def sample_dict_to_target_points(cropped_pts_dict, total_points=15000):
        all_sampled_points = []
        sampled_pts_dict = {}
        nonempty_names = [name for name, pts in cropped_pts_dict.items() if pts.shape[0] > 0]
        total_existing_points = sum(cropped_pts_dict[name].shape[0] for name in nonempty_names)
        if total_existing_points > total_points:
            # Downsample: give each object a share proportional to its size,
            # then hand out the rounding remainder one point at a time.
            ratios = {name: len(cropped_pts_dict[name]) / total_existing_points for name in nonempty_names}
            target_num_points = {name: int(ratio * total_points) for name, ratio in ratios.items()}
            remaining_points = total_points - sum(target_num_points.values())
            for name in target_num_points.keys():
                if remaining_points > 0:
                    target_num_points[name] += 1
                    remaining_points -= 1
        else:
            # Upsample: keep every existing point and draw the shortfall from
            # randomly chosen non-empty objects (with replacement).
            target_num_points = {name: len(pts) for name, pts in cropped_pts_dict.items()}
            remaining_points = total_points - total_existing_points
            if remaining_points > 0 and len(nonempty_names) > 0:
                additional_points = np.random.choice(nonempty_names, remaining_points, replace=True)
                for name in additional_points:
                    target_num_points[name] += 1
        for name, pts in cropped_pts_dict.items():
            if pts.shape[0] == 0:
                sampled_pts_dict[name] = pts
                continue
            sampled_pts = GSNetInferenceDataset.sample_points(pts, target_num_points[name])
            sampled_pts_dict[name] = sampled_pts
            all_sampled_points.append(sampled_pts)
        if len(all_sampled_points) > 0:
            sampled_scene_pts = np.concatenate(all_sampled_points, axis=0)
        else:
            sampled_scene_pts = np.zeros((total_points, 3))
        return sampled_scene_pts, sampled_pts_dict

    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]
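
# Illustrative sketch (not part of the pipeline): how sample_dict_to_target_points
# redistributes a fixed point budget across objects. The object names and random
# clouds below are hypothetical values chosen purely for demonstration.
def _demo_proportional_sampling():
    clouds = {
        "mug": np.random.rand(8000, 3),   # hypothetical object cloud
        "box": np.random.rand(12000, 3),  # hypothetical object cloud
        "empty": np.zeros((0, 3)),        # empty clouds are passed through untouched
    }
    scene_pts, per_obj = GSNetInferenceDataset.sample_dict_to_target_points(clouds, total_points=15000)
    # The concatenated scene keeps the 15000-point budget, and each non-empty
    # object retains roughly its original share (here ~6000 and ~9000 points).
    print(scene_pts.shape, {name: pts.shape[0] for name, pts in per_obj.items()})
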
class GSNetPreprocessor(GraspingPreprocessor):
    GRASP_MAX_WIDTH = 0.1
    GRASPNESS_THRESHOLD = 0.1
    NUM_VIEW = 300
    NUM_ANGLE = 12
    NUM_DEPTH = 4
    M_POINT = 1024

    def __init__(self, config_path):
        super().__init__(config_path)

    def get_dataloader(self, dataset_config):
        def my_worker_init_fn(worker_id):
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset(
            source=dataset_config["source"],
            data_type=dataset_config["data_type"],
            data_dir=dataset_config["data_dir"],
            scene_pts_num=dataset_config["scene_pts_num"],
            voxel_size=dataset_config["voxel_size"],
        )
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=dataset_config["batch_size"],
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader

    def get_model(self, model_config=None):
        model = GraspNet(seed_feat_dim=model_config["general"]["seed_feat_dim"], is_training=False)
        model.to("cuda")
        checkpoint = torch.load(model_config["general"]["checkpoint_path"])
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(
            "-> loaded checkpoint %s (epoch: %d)"
            % (model_config["general"]["checkpoint_path"], start_epoch)
        )
        model.eval()
        return model
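    # prediction() runs in two phases: (1) batched network inference that
    # collects per-object grasp centers and scores (plus optional NMS-filtered
    # gripper geometry via graspnetAPI), and (2) postprocessing that matches
    # each grasp center back to the target object's points and keeps the
    # top-k scored positions per object.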
    def prediction(self, model, dataloader, require_gripper=False, top_k=10):
        preds = {}
        for batch_idx, batch_data in enumerate(dataloader):
            try:
                if "error" in batch_data:
                    frame_path = batch_data["frame_path"][0]
                    object_name = batch_data["object_name"][0]
                    preds[frame_path] = {object_name: None}
                    print("No graspable points found at frame: ", frame_path)
                    continue
                print("Processing batch: ", batch_idx, "/", len(dataloader))
                # Move tensors (including nested tensor lists) to the GPU.
                for key in batch_data:
                    if "list" in key:
                        for i in range(len(batch_data[key])):
                            for j in range(len(batch_data[key][i])):
                                batch_data[key][i][j] = batch_data[key][i][j].to("cuda")
                    elif not isinstance(batch_data[key], list):
                        batch_data[key] = batch_data[key].to("cuda")
                with torch.no_grad():
                    end_points = model(batch_data)
                    if end_points is None:
                        frame_path = batch_data["frame_path"][0]
                        object_name = batch_data["object_name"][0]
                        preds[frame_path] = {object_name: None}
                        print("No graspable points found at frame: ", frame_path)
                        continue
                    grasp_preds = self.decode_pred(end_points)
                    standard_grasp_preds = GSNetPreprocessor.standard_pred_decode(end_points)
                    standard_preds = standard_grasp_preds[0].detach().cpu().numpy()
                    if require_gripper:
                        gg = GraspGroup(standard_preds)
                        gg = gg.nms()
                        gg = gg.sort_by_score()
                        grippers = gg.to_open3d_geometry_list()
                        gp_pts_list = np.asarray(
                            [np.asarray(gripper_mesh.sample_points_uniformly(48).points) for gripper_mesh in grippers],
                            dtype=np.float16,
                        )
                        gp_score_list = gg.scores
                for sample_idx in range(len(batch_data["frame_path"])):
                    frame_path = batch_data["frame_path"][sample_idx]
                    object_name = batch_data["object_name"][sample_idx]
                    if frame_path not in preds:
                        preds[frame_path] = {object_name: {}}
                    preds[frame_path][object_name] = grasp_preds[sample_idx]
                    preds[frame_path][object_name]["obj_pcl_dict"] = batch_data["obj_pcl_dict"][sample_idx]
                    if require_gripper:
                        preds[frame_path][object_name]["gripper"] = {
                            "gripper_pose": gp_pts_list.tolist(),
                            "gripper_score": gp_score_list.tolist(),
                        }
            except Exception as e:
                # Record an empty result for the failed batch so that
                # postprocessing skips it instead of crashing.
                print("Error in inference: ", e)
                frame_path = batch_data["frame_path"][0]
                object_name = batch_data["object_name"][0]
                preds[frame_path] = {object_name: None}

        results = {}
        for frame_path in preds:
            try:
                predict_results = {}
                for object_name in preds[frame_path]:
                    if object_name is None or preds[frame_path][object_name] is None:
                        continue
                    grasp_center = preds[frame_path][object_name]["grasp_center"]
                    grasp_score = preds[frame_path][object_name]["grasp_score"]
                    obj_pcl_dict = preds[frame_path][object_name]["obj_pcl_dict"]
                    if require_gripper:
                        gripper = preds[frame_path][object_name]["gripper"]
                    grasp_center = grasp_center.unsqueeze(1)
                    obj_pcl = obj_pcl_dict[object_name]
                    obj_pcl = torch.tensor(
                        obj_pcl.astype(np.float32), device=grasp_center.device
                    )
                    obj_pcl = obj_pcl.unsqueeze(0)
                    # A grasp belongs to the target object iff its center coincides
                    # exactly with one of the object's points; both are drawn from
                    # the same sampled scene cloud, so equality comparison is safe.
                    grasp_obj_table = (grasp_center == obj_pcl).all(axis=-1)
                    obj_pts_on_grasp = grasp_obj_table.any(axis=1)
                    obj_graspable_pts = grasp_center[obj_pts_on_grasp].squeeze(1)
                    obj_graspable_pts_score = grasp_score[obj_pts_on_grasp]
                    obj_graspable_pts_info = torch.cat(
                        [obj_graspable_pts, obj_graspable_pts_score], dim=1
                    )
                    if obj_graspable_pts.shape[0] == 0:
                        obj_graspable_pts_info = torch.zeros((top_k, 4))
                    ranked_obj_graspable_pts_info = self.sample_graspable_pts(
                        obj_graspable_pts_info, top_k=top_k
                    )
                    predict_results[object_name] = {
                        "positions": ranked_obj_graspable_pts_info[:, :3].cpu().numpy().tolist(),
                        "scores": ranked_obj_graspable_pts_info[:, 3].cpu().numpy().tolist(),
                    }
                if require_gripper:
                    results[frame_path] = {"predicted_results": predict_results, "gripper": gripper}
                else:
                    results[frame_path] = {"predicted_results": predict_results}
            except Exception as e:
                print("Error in postprocessing: ", e)
        print("Prediction finished")
        return results
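    # sample_graspable_pts pads its candidate set by resampling with
    # replacement when fewer than top_k grasps survive, then ranks rows by
    # the score column (index 3) and keeps the best top_k.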
    @staticmethod
    def sample_graspable_pts(graspable_pts, top_k=50):
        if graspable_pts.shape[0] < top_k:
            sampled_indices = torch.randint(0, graspable_pts.shape[0], (top_k,))
            graspable_pts = graspable_pts[sampled_indices]
        sorted_indices = torch.argsort(graspable_pts[:, 3], descending=True)
        top_graspable_pts = graspable_pts[sorted_indices][:top_k]
        return top_graspable_pts

    def decode_pred(self, end_points):
        batch_size = len(end_points["point_clouds"])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points["xyz_graspable"][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points["grasp_score_pred"][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            grasp_score, _ = torch.max(grasp_score, -1)  # [M_POINT]
            grasp_score = grasp_score.view(-1, 1)
            grasp_preds.append(
                {"grasp_center": grasp_center, "grasp_score": grasp_score}
            )
        return grasp_preds

    @staticmethod
    def standard_pred_decode(end_points):
        batch_size = len(end_points["point_clouds"])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points["xyz_graspable"][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points["grasp_score_pred"][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            grasp_score, grasp_score_inds = torch.max(grasp_score, -1)  # [M_POINT]
            grasp_score = grasp_score.view(-1, 1)
            grasp_angle = (grasp_score_inds // GSNetPreprocessor.NUM_DEPTH) * np.pi / GSNetPreprocessor.NUM_ANGLE
            grasp_depth = (grasp_score_inds % GSNetPreprocessor.NUM_DEPTH + 1) * 0.01
            grasp_depth = grasp_depth.view(-1, 1)
            grasp_width = 1.2 * end_points["grasp_width_pred"][i] / 10.0
            grasp_width = grasp_width.view(
                GSNetPreprocessor.M_POINT, GSNetPreprocessor.NUM_ANGLE * GSNetPreprocessor.NUM_DEPTH
            )
            grasp_width = torch.gather(grasp_width, 1, grasp_score_inds.view(-1, 1))
            grasp_width = torch.clamp(grasp_width, min=0.0, max=GSNetPreprocessor.GRASP_MAX_WIDTH)
            approaching = -end_points["grasp_top_view_xyz"][i].float()
            grasp_rot = GSNetPreprocessor.batch_viewpoint_params_to_matrix(approaching, grasp_angle)
            grasp_rot = grasp_rot.view(GSNetPreprocessor.M_POINT, 9)
            # Merge predictions into the standard GraspGroup layout:
            # [score, width, height, depth, rotation(9), center(3), object_id].
            grasp_height = 0.02 * torch.ones_like(grasp_score)
            obj_ids = -1 * torch.ones_like(grasp_score)
            grasp_preds.append(
                torch.cat(
                    [grasp_score, grasp_width, grasp_height, grasp_depth, grasp_rot, grasp_center, obj_ids],
                    axis=-1,
                )
            )
        return grasp_preds

    @staticmethod
    def batch_viewpoint_params_to_matrix(batch_towards, batch_angle):
        axis_x = batch_towards
        ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
        zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
        axis_y = torch.stack([-axis_x[:, 1], axis_x[:, 0], zeros], dim=-1)
        # Fall back to the y-axis when the approach vector is parallel to z.
        mask_y = torch.norm(axis_y, dim=-1) == 0
        axis_y[mask_y, 1] = 1
        axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True)
        axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True)
        axis_z = torch.cross(axis_x, axis_y, dim=-1)
        sin = torch.sin(batch_angle)
        cos = torch.cos(batch_angle)
        # In-plane rotation about the approach (x) axis.
        R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1)
        R1 = R1.reshape([-1, 3, 3])
        R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1)
        batch_matrix = torch.matmul(R2, R1)
        return batch_matrix


if __name__ == "__main__":
    gs_preproc = GSNetPreprocessor(config_path="configs/server_gsnet_preprocess_config.yaml")
    gs_preproc.run()
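
# Illustrative sanity check (not part of the pipeline): the viewpoint-to-matrix
# helper should produce orthonormal rotations. The approach vectors and angles
# below are hypothetical values, including the degenerate z-parallel case.
def _demo_rotation_sanity_check():
    towards = torch.tensor([[0.0, 0.0, -1.0], [1.0, 1.0, 0.0]])
    angles = torch.tensor([0.0, np.pi / 4])
    rot = GSNetPreprocessor.batch_viewpoint_params_to_matrix(towards, angles)
    # R^T R should be (close to) the identity for every matrix in the batch.
    eye = torch.matmul(rot.transpose(1, 2), rot)
    print(torch.allclose(eye, torch.eye(3).expand_as(eye), atol=1e-5))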