""" GraspNet baseline model definition.
    Author: chenxi-wang
"""

import os
import sys
import numpy as np
import torch
import torch.nn as nn
import MinkowskiEngine as ME

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)

from models.backbone_resunet14 import MinkUNet14D
from models.modules import ApproachNet, GraspableNet, CloudCrop, SWADNet
from loss_utils import GRASP_MAX_WIDTH, NUM_VIEW, NUM_ANGLE, NUM_DEPTH, GRASPNESS_THRESHOLD, M_POINT
from label_generation import process_grasp_labels, match_grasp_view_and_label, batch_viewpoint_params_to_matrix
from pointnet2.pointnet2_utils import furthest_point_sample, gather_operation
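
# NUM_VIEW, NUM_ANGLE and NUM_DEPTH set the discretization of approach views,
# in-plane rotation angles and gripper depths; M_POINT is the number of seed points
# kept per scene after farthest point sampling, GRASPNESS_THRESHOLD filters out
# low-graspness points, and GRASP_MAX_WIDTH caps the decoded widths. Their values
# are defined in loss_utils.py; note that pred_decode below hard-codes a pi / 12
# angle step, which matches NUM_ANGLE == 12.

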
class GraspNet(nn.Module):
    def __init__(self, cylinder_radius=0.05, seed_feat_dim=512, is_training=True):
        super().__init__()
        self.is_training = is_training
        self.seed_feature_dim = seed_feat_dim
        self.num_depth = NUM_DEPTH
        self.num_angle = NUM_ANGLE
        self.M_points = M_POINT
        self.num_view = NUM_VIEW

        # sparse 3D backbone producing per-point seed features
        self.backbone = MinkUNet14D(in_channels=3, out_channels=self.seed_feature_dim, D=3)
        # per-point objectness and graspness prediction
        self.graspable = GraspableNet(seed_feature_dim=self.seed_feature_dim)
        # approach-view (rotation) prediction for each seed point
        self.rotation = ApproachNet(self.num_view, seed_feature_dim=self.seed_feature_dim, is_training=self.is_training)
        # cylinder grouping of neighboring points around each seed
        self.crop = CloudCrop(nsample=16, cylinder_radius=cylinder_radius, seed_feature_dim=self.seed_feature_dim)
        # per-seed score and width prediction over the angle x depth grid
        self.swad = SWADNet(num_angle=self.num_angle, num_depth=self.num_depth)
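
    # Forward pass, stage by stage: sparse backbone features -> graspable-point
    # filtering (objectness + graspness) -> farthest point sampling to M_points
    # seeds per sample -> approach-view prediction -> cylinder cropping -> SWAD
    # score / width heads.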
    def forward(self, end_points):
        seed_xyz = end_points['point_clouds']  # use all sampled point cloud, B*Ns*3
        B, point_num, _ = seed_xyz.shape  # batch size, number of sampled points

        # point-wise features from the sparse backbone
        coordinates_batch = end_points['coors']
        features_batch = end_points['feats']
        mink_input = ME.SparseTensor(features_batch, coordinates=coordinates_batch)
        seed_features = self.backbone(mink_input).F
        # map voxel features back to the original points and reshape to B*feat_dim*Ns
        seed_features = seed_features[end_points['quantize2original']].view(B, point_num, -1).transpose(1, 2)

        end_points = self.graspable(seed_features, end_points)
        seed_features_flipped = seed_features.transpose(1, 2)  # B*Ns*feat_dim
        objectness_score = end_points['objectness_score']
        graspness_score = end_points['graspness_score'].squeeze(1)
        # a point is graspable if it is classified as lying on an object and its
        # graspness score exceeds GRASPNESS_THRESHOLD
        objectness_pred = torch.argmax(objectness_score, 1)
        objectness_mask = (objectness_pred == 1)
        graspness_mask = graspness_score > GRASPNESS_THRESHOLD
        graspable_mask = objectness_mask & graspness_mask

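        # For each sample: keep only the graspable points, then farthest-point-sample
        # them down to M_points seeds so every sample contributes the same number of seeds.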
        seed_features_graspable = []
        seed_xyz_graspable = []
        graspable_num_batch = 0.
        for i in range(B):
            cur_mask = graspable_mask[i]
            graspable_num_batch += cur_mask.sum()
            if cur_mask.sum() == 0:
                # no graspable point in this sample, abort the forward pass
                return None
            cur_feat = seed_features_flipped[i][cur_mask]  # Ns*feat_dim
            cur_seed_xyz = seed_xyz[i][cur_mask]  # Ns*3

            cur_seed_xyz = cur_seed_xyz.unsqueeze(0)  # 1*Ns*3
            fps_idxs = furthest_point_sample(cur_seed_xyz, self.M_points)
            cur_seed_xyz_flipped = cur_seed_xyz.transpose(1, 2).contiguous()  # 1*3*Ns
            cur_seed_xyz = gather_operation(cur_seed_xyz_flipped, fps_idxs).transpose(1, 2).squeeze(0).contiguous()  # M_points*3
            cur_feat_flipped = cur_feat.unsqueeze(0).transpose(1, 2).contiguous()  # 1*feat_dim*Ns
            cur_feat = gather_operation(cur_feat_flipped, fps_idxs).squeeze(0).contiguous()  # feat_dim*M_points

            seed_features_graspable.append(cur_feat)
            seed_xyz_graspable.append(cur_seed_xyz)
        seed_xyz_graspable = torch.stack(seed_xyz_graspable, 0)  # B*M_points*3
        seed_features_graspable = torch.stack(seed_features_graspable)  # B*feat_dim*M_points

        end_points['xyz_graspable'] = seed_xyz_graspable
        end_points['graspable_count_stage1'] = graspable_num_batch / B

        # predict approach views for each seed and fuse the residual features
        end_points, res_feat = self.rotation(seed_features_graspable, end_points)
        seed_features_graspable = seed_features_graspable + res_feat

        if self.is_training:
            # generate grasp labels and match them to the predicted views
            end_points = process_grasp_labels(end_points)
            grasp_top_views_rot, end_points = match_grasp_view_and_label(end_points)
        else:
            grasp_top_views_rot = end_points['grasp_top_view_rot']

        # crop cylinder regions around the seeds along the selected views,
        # then predict score and width for every angle-depth bin
        group_features = self.crop(seed_xyz_graspable.contiguous(), seed_features_graspable.contiguous(), grasp_top_views_rot)
        end_points = self.swad(group_features, end_points)

        return end_points

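
# Decode the raw network outputs into flat grasp arrays, one per scene in the batch.
# For each of the M_POINT seed points the best (angle, depth) bin is selected by
# predicted score, giving one grasp per seed in the form
# [score, width, height, depth, rotation (9), center (3), object_id] (17 values),
# which matches the grasp array layout used by graspnetAPI's GraspGroup.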
def pred_decode(end_points):
    batch_size = len(end_points['point_clouds'])
    grasp_preds = []
    for i in range(batch_size):
        grasp_center = end_points['xyz_graspable'][i].float()

        # pick the best (angle, depth) bin per seed point by predicted score
        grasp_score = end_points['grasp_score_pred'][i].float()
        grasp_score = grasp_score.view(M_POINT, NUM_ANGLE*NUM_DEPTH)
        grasp_score, grasp_score_inds = torch.max(grasp_score, -1)  # [M_POINT]
        grasp_score = grasp_score.view(-1, 1)
        grasp_angle = (grasp_score_inds // NUM_DEPTH) * np.pi / 12  # angle bin width, assumes NUM_ANGLE == 12
        grasp_depth = (grasp_score_inds % NUM_DEPTH + 1) * 0.01  # depth bins of 0.01 m
        grasp_depth = grasp_depth.view(-1, 1)
        # rescale the predicted width and clamp it to the gripper limit
        grasp_width = 1.2 * end_points['grasp_width_pred'][i] / 10.
        grasp_width = grasp_width.view(M_POINT, NUM_ANGLE*NUM_DEPTH)
        grasp_width = torch.gather(grasp_width, 1, grasp_score_inds.view(-1, 1))
        grasp_width = torch.clamp(grasp_width, min=0., max=GRASP_MAX_WIDTH)

        # build the rotation matrix from the approach direction and in-plane angle
        approaching = -end_points['grasp_top_view_xyz'][i].float()
        grasp_rot = batch_viewpoint_params_to_matrix(approaching, grasp_angle)
        grasp_rot = grasp_rot.view(M_POINT, 9)

        # merge preds
        grasp_height = 0.02 * torch.ones_like(grasp_score)  # constant grasp height
        obj_ids = -1 * torch.ones_like(grasp_score)  # object id unknown at inference
        grasp_preds.append(
            torch.cat([grasp_score, grasp_width, grasp_height, grasp_depth, grasp_rot, grasp_center, obj_ids], dim=-1))
    return grasp_preds
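

# Minimal usage sketch (illustrative only, not part of the original file). It assumes
# `end_points` is built the same way as in the repo's dataset / collate code:
# 'point_clouds' (B*Ns*3), plus the MinkowskiEngine inputs 'coors', 'feats' and the
# 'quantize2original' index map consumed in forward(). The dataloader below is hypothetical.
if __name__ == '__main__':
    net = GraspNet(cylinder_radius=0.05, seed_feat_dim=512, is_training=False)
    net.eval()
    # end_points = next(iter(test_dataloader))           # hypothetical dataloader
    # with torch.no_grad():
    #     end_points = net(end_points)
    #     if end_points is not None:                      # None -> no graspable points
    #         grasp_preds = pred_decode(end_points)       # list of (M_POINT, 17) tensors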