""" GraspNet baseline model definition.
Author: chenxi-wang
"""
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import MinkowskiEngine as ME
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)
from models.backbone_resunet14 import MinkUNet14D
from models.modules import ApproachNet, GraspableNet, CloudCrop, SWADNet
from loss_utils import GRASP_MAX_WIDTH, NUM_VIEW, NUM_ANGLE, NUM_DEPTH, GRASPNESS_THRESHOLD, M_POINT
from label_generation import process_grasp_labels, match_grasp_view_and_label, batch_viewpoint_params_to_matrix
from pointnet2.pointnet2_utils import furthest_point_sample, gather_operation


class GraspNet(nn.Module):
    def __init__(self, cylinder_radius=0.05, seed_feat_dim=512, is_training=True):
        super().__init__()
        self.is_training = is_training
        self.seed_feature_dim = seed_feat_dim
        self.num_depth = NUM_DEPTH
        self.num_angle = NUM_ANGLE
        self.M_points = M_POINT
        self.num_view = NUM_VIEW

        self.backbone = MinkUNet14D(in_channels=3, out_channels=self.seed_feature_dim, D=3)
        self.graspable = GraspableNet(seed_feature_dim=self.seed_feature_dim)
        self.rotation = ApproachNet(self.num_view, seed_feature_dim=self.seed_feature_dim, is_training=self.is_training)
        self.crop = CloudCrop(nsample=16, cylinder_radius=cylinder_radius, seed_feature_dim=self.seed_feature_dim)
        self.swad = SWADNet(num_angle=self.num_angle, num_depth=self.num_depth)
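
    # forward pipeline: sparse 3D backbone -> per-point objectness / graspness
    # (GraspableNet) -> graspable seed sampling (FPS) -> approach view prediction
    # (ApproachNet) -> cylinder grouping (CloudCrop) -> per-seed score / width over
    # angle x depth bins (SWADNet)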
    def forward(self, end_points):
        seed_xyz = end_points['point_clouds']  # use all sampled point cloud, B*Ns*3
        B, point_num, _ = seed_xyz.shape  # batch size

        # point-wise features
        coordinates_batch = end_points['coors']
        features_batch = end_points['feats']
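
        # run the sparse 3D backbone on the voxelized input, then map the per-voxel
        # features back onto the original Ns points via 'quantize2original'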
        mink_input = ME.SparseTensor(features_batch, coordinates=coordinates_batch)
        seed_features = self.backbone(mink_input).F
        seed_features = seed_features[end_points['quantize2original']].view(B, point_num, -1).transpose(1, 2)

        end_points = self.graspable(seed_features, end_points)
        seed_features_flipped = seed_features.transpose(1, 2)  # B*Ns*feat_dim
        objectness_score = end_points['objectness_score']
        graspness_score = end_points['graspness_score'].squeeze(1)
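
        # a point is kept as a grasp seed only if it is classified as on-object
        # (objectness class 1) and its graspness exceeds GRASPNESS_THRESHOLD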
        objectness_pred = torch.argmax(objectness_score, 1)
        objectness_mask = (objectness_pred == 1)
        graspness_mask = graspness_score > GRASPNESS_THRESHOLD
        graspable_mask = objectness_mask & graspness_mask

        seed_features_graspable = []
        seed_xyz_graspable = []
        graspable_num_batch = 0.
        for i in range(B):
            cur_mask = graspable_mask[i]
            graspable_num_batch += cur_mask.sum()
            if cur_mask.sum() == 0:
                # no graspable point in this scene; bail out so FPS below is not
                # called on an empty point set
                return None
            cur_feat = seed_features_flipped[i][cur_mask]  # Ns*feat_dim
            cur_seed_xyz = seed_xyz[i][cur_mask]  # Ns*3

            # furthest point sampling down to a fixed M_points seeds per scene so the
            # per-scene tensors can be stacked across the batch
            cur_seed_xyz = cur_seed_xyz.unsqueeze(0)  # 1*Ns*3
            fps_idxs = furthest_point_sample(cur_seed_xyz, self.M_points)
            cur_seed_xyz_flipped = cur_seed_xyz.transpose(1, 2).contiguous()  # 1*3*Ns
            cur_seed_xyz = gather_operation(cur_seed_xyz_flipped, fps_idxs).transpose(1, 2).squeeze(0).contiguous()  # Ns*3
            cur_feat_flipped = cur_feat.unsqueeze(0).transpose(1, 2).contiguous()  # 1*feat_dim*Ns
            cur_feat = gather_operation(cur_feat_flipped, fps_idxs).squeeze(0).contiguous()  # feat_dim*Ns
            seed_features_graspable.append(cur_feat)
            seed_xyz_graspable.append(cur_seed_xyz)
        seed_xyz_graspable = torch.stack(seed_xyz_graspable, 0)  # B*Ns*3
        seed_features_graspable = torch.stack(seed_features_graspable)  # B*feat_dim*Ns
        end_points['xyz_graspable'] = seed_xyz_graspable
        end_points['graspable_count_stage1'] = graspable_num_batch / B
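
        # ApproachNet scores the candidate approach views for each seed and returns
        # a residual feature that is added back onto the seed features before grouping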
        end_points, res_feat = self.rotation(seed_features_graspable, end_points)
        seed_features_graspable = seed_features_graspable + res_feat

        if self.is_training:
            end_points = process_grasp_labels(end_points)
            grasp_top_views_rot, end_points = match_grasp_view_and_label(end_points)
        else:
            grasp_top_views_rot = end_points['grasp_top_view_rot']

        group_features = self.crop(seed_xyz_graspable.contiguous(), seed_features_graspable.contiguous(), grasp_top_views_rot)
        end_points = self.swad(group_features, end_points)

        return end_points


def pred_decode(end_points):
    batch_size = len(end_points['point_clouds'])
    grasp_preds = []
    for i in range(batch_size):
        grasp_center = end_points['xyz_graspable'][i].float()

        grasp_score = end_points['grasp_score_pred'][i].float()
        grasp_score = grasp_score.view(M_POINT, NUM_ANGLE*NUM_DEPTH)
        grasp_score, grasp_score_inds = torch.max(grasp_score, -1)  # [M_POINT]
        grasp_score = grasp_score.view(-1, 1)
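        # grasp_score_inds indexes the flattened NUM_ANGLE x NUM_DEPTH grid: integer
        # division by NUM_DEPTH gives the angle bin, the remainder the depth bin.
        # the hard-coded 12 and 0.01 below assume NUM_ANGLE == 12 (pi/12 rad per bin)
        # and 1 cm depth steps, following the GraspNet-1Billion convention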
        grasp_angle = (grasp_score_inds // NUM_DEPTH) * np.pi / 12
        grasp_depth = (grasp_score_inds % NUM_DEPTH + 1) * 0.01
        grasp_depth = grasp_depth.view(-1, 1)
        grasp_width = 1.2 * end_points['grasp_width_pred'][i] / 10.
        grasp_width = grasp_width.view(M_POINT, NUM_ANGLE*NUM_DEPTH)
        grasp_width = torch.gather(grasp_width, 1, grasp_score_inds.view(-1, 1))
        grasp_width = torch.clamp(grasp_width, min=0., max=GRASP_MAX_WIDTH)
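
        # the approach direction is the negated predicted view vector; together with
        # the in-plane angle it is converted to one 3x3 rotation matrix per seed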
        approaching = -end_points['grasp_top_view_xyz'][i].float()
        grasp_rot = batch_viewpoint_params_to_matrix(approaching, grasp_angle)
        grasp_rot = grasp_rot.view(M_POINT, 9)

        # merge preds: each row is [score, width, height, depth, rotation(9), center(3), obj_id]
        grasp_height = 0.02 * torch.ones_like(grasp_score)
        obj_ids = -1 * torch.ones_like(grasp_score)
        grasp_preds.append(
            torch.cat([grasp_score, grasp_width, grasp_height, grasp_depth, grasp_rot, grasp_center, obj_ids], axis=-1))
    return grasp_preds
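

# Minimal usage sketch (not part of the original file). It shows how an
# 'end_points' dict matching the keys read by GraspNet.forward() could be built
# with MinkowskiEngine quantization. The voxel_size value and the
# sparse_collate / sparse_quantize pipeline are assumptions modelled on the
# usual graspness data loader, not a guaranteed reproduction of it.
if __name__ == '__main__':
    voxel_size = 0.005  # assumed voxel size in metres
    device = torch.device('cuda')  # pointnet2 ops require CUDA

    cloud = np.random.rand(15000, 3).astype(np.float32)  # stand-in point cloud, Ns*3
    coors, feats = ME.utils.sparse_collate([cloud / voxel_size], [np.ones_like(cloud)])
    coors, feats, _, quantize2original = ME.utils.sparse_quantize(
        coors, feats, return_index=True, return_inverse=True)

    end_points = {
        'point_clouds': torch.from_numpy(cloud).unsqueeze(0).to(device),  # 1*Ns*3
        'coors': coors.to(device),
        'feats': feats.to(device),
        'quantize2original': quantize2original.to(device),
    }

    net = GraspNet(seed_feat_dim=512, is_training=False).to(device)
    net.eval()
    with torch.no_grad():
        end_points = net(end_points)
        grasps = pred_decode(end_points)  # list of (M_POINT, 17) tensors per scene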