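"""GSNet-based grasping preprocessor.

`GSNetInferenceDataset` indexes rendered frames (camera parameters, depth, RGB
and semantic-segmentation files, read through OmniUtil), crops the scene around
each target object and resamples it to a fixed point budget.
`GSNetPreprocessor` wraps a pretrained GraspNet (GSNet) checkpoint behind the
GraspingPreprocessor interface and exports, per frame and target object, the
graspable points, their scores and optionally sampled gripper geometry.
"""
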
import os
import re
import sys
import numpy as np
import torch
import open3d as o3d
from torch.utils.data import DataLoader, Dataset

# Walk four directory levels up from this file to reach the project root, then
# expose the project and the vendored GSNet packages on sys.path so the
# absolute imports below resolve.
path = os.path.abspath(__file__)
for _ in range(4):
    path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)
GSNET_PROJECT_ROOT = os.path.join(PROJECT_ROOT, "baselines/grasping/GSNet")
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "pointnet2"))
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "utils"))
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "models"))
sys.path.append(os.path.join(GSNET_PROJECT_ROOT, "dataset"))

from utils.omni_util import OmniUtil
from utils.view_util import ViewUtil
from runners.preprocessors.grasping.abstract_grasping_preprocessor import GraspingPreprocessor
from configs.config import ConfigManager

from baselines.grasping.GSNet.models.graspnet import GraspNet
from baselines.grasping.GSNet.graspnetAPI.graspnetAPI.graspnet_eval import GraspGroup
from baselines.grasping.GSNet.dataset.graspnet_dataset import minkowski_collate_fn


class GSNetInferenceDataset(Dataset):
    """Yields one cropped, fixed-size point cloud per (frame, target object) pair."""

    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        scene_pts_num=15000,
        voxel_size=0.005,
    ):
        self.data_dir = data_dir
        self.scene_pts_num = scene_pts_num
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.voxel_size = voxel_size

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path, target = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path, object_name=target)
        return frame_data

    def get_datalist(self):
        """Scan every scene folder and collect (frame_path, target_object) pairs."""
        scene_frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            file_list = os.listdir(scene_path)
            for file in file_list:
                # One "camera_params_<idx>.json" exists per rendered frame.
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    target_list = OmniUtil.get_object_list(frame_path)
                    for target in target_list:
                        scene_frame_list.append((frame_path, target))
                    # Keep frames without any labelled object so they can be skipped later.
                    if len(target_list) == 0:
                        scene_frame_list.append((frame_path, None))
            print("Scene: ", scene, ", accumulated (frame, target) pairs: ", len(scene_frame_list))
        return scene_frame_list

    def load_frame_data(self, frame_path, object_name):
        try:
            # Segment the frame into per-object point clouds, crop a 0.2 m ball
            # around the target object, and resample to a fixed point budget.
            target_list = OmniUtil.get_object_list(path=frame_path, contains_non_obj=True)
            _, obj_pcl_dict = OmniUtil.get_segmented_points(
                path=frame_path, target_list=target_list
            )
            obj_center = ViewUtil.get_object_center_from_pts_dict(object_name, obj_pcl_dict)
            croped_pts_dict = ViewUtil.crop_pts_dict(obj_pcl_dict, obj_center, radius=0.2)
            sampled_scene_pts, sampled_pts_dict = GSNetInferenceDataset.sample_dict_to_target_points(croped_pts_dict)
            ret_dict = {
                "frame_path": frame_path,
                "point_clouds": sampled_scene_pts.astype(np.float32),
                "coors": sampled_scene_pts.astype(np.float32) / self.voxel_size,
                "feats": np.ones_like(sampled_scene_pts).astype(np.float32),
                "obj_pcl_dict": sampled_pts_dict,
                "object_name": object_name,
            }
        except Exception as e:
            # Fall back to an all-zero sample so the dataloader keeps running;
            # the "error" flag lets the preprocessor skip this frame.
            print("Error in loading frame data: ", e)
            ret_dict = {
                "frame_path": frame_path,
                "point_clouds": np.zeros((self.scene_pts_num, 3)).astype(np.float32),
                "coors": np.zeros((self.scene_pts_num, 3)).astype(np.float32),
                "feats": np.ones((self.scene_pts_num, 3)).astype(np.float32),
                "obj_pcl_dict": {},
                "object_name": object_name,
                "error": True,
            }
        return ret_dict
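
    # Note: the "point_clouds", "coors" and "feats" fields mirror what GSNet's
    # minkowski_collate_fn and GraspNet consume: raw xyz, voxel-grid coordinates
    # (xyz / voxel_size) and constant per-point features. This follows the
    # conventions of the original graspnet_dataset.py; the exact contract is
    # defined there.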

    @staticmethod
    def sample_points(points, target_num_points):
        # Down-sample without replacement, or up-sample with replacement, to
        # exactly target_num_points points.
        num_points = points.shape[0]
        if num_points == 0:
            return np.zeros((target_num_points, points.shape[1]))
        if num_points > target_num_points:
            indices = np.random.choice(num_points, target_num_points, replace=False)
        else:
            indices = np.random.choice(num_points, target_num_points, replace=True)
        return points[indices]

    @staticmethod
    def sample_dict_to_target_points(croped_pts_dict, total_points=15000):
        all_sampled_points = []
        sampled_pts_dict = {}
        total_existing_points = sum([pts.shape[0] for pts in croped_pts_dict.values() if pts.shape[0] > 0])

        if total_existing_points > total_points:
            # Too many points: give each object a share proportional to its size,
            # then hand out the points lost to integer rounding one by one.
            ratios = {name: len(pts) / total_existing_points for name, pts in croped_pts_dict.items() if pts.shape[0] > 0}
            target_num_points = {name: int(ratio * total_points) for name, ratio in ratios.items()}
            remaining_points = total_points - sum(target_num_points.values())
            for name in target_num_points.keys():
                if remaining_points > 0:
                    target_num_points[name] += 1
                    remaining_points -= 1
        else:
            # Too few points: keep every existing point and pad the budget by
            # re-sampling randomly chosen non-empty objects with replacement.
            target_num_points = {name: len(pts) for name, pts in croped_pts_dict.items()}
            remaining_points = total_points - total_existing_points
            additional_points = np.random.choice([name for name, pts in croped_pts_dict.items() if pts.shape[0] > 0], remaining_points, replace=True)
            for name in additional_points:
                target_num_points[name] += 1

        for name, pts in croped_pts_dict.items():
            if pts.shape[0] == 0:
                sampled_pts_dict[name] = pts
                continue
            sampled_pts = GSNetInferenceDataset.sample_points(pts, target_num_points[name])
            sampled_pts_dict[name] = sampled_pts
            all_sampled_points.append(sampled_pts)

        if len(all_sampled_points) > 0:
            sampled_scene_pts = np.concatenate(all_sampled_points, axis=0)
        else:
            sampled_scene_pts = np.zeros((total_points, 3))
        return sampled_scene_pts, sampled_pts_dict
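
    # Illustrative example of the allocation above (hypothetical numbers): with
    # total_points=15000 and a cropped scene of 20000 points split 15000/5000
    # between two objects, the proportional pass assigns 11250 and 3750 points;
    # if instead only 6000 points exist, all of them are kept and the remaining
    # 9000 slots are filled by re-sampling the two objects with replacement.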

    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]


class GSNetPreprocessor(GraspingPreprocessor):
    GRASP_MAX_WIDTH = 0.1
    GRASPNESS_THRESHOLD = 0.1
    NUM_VIEW = 300
    NUM_ANGLE = 12
    NUM_DEPTH = 4
    M_POINT = 1024

    def __init__(self, config_path):
        super().__init__(config_path)

    def get_dataloader(self, dataset_config):
        def my_worker_init_fn(worker_id):
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset(
            source=dataset_config["source"],
            data_type=dataset_config["data_type"],
            data_dir=dataset_config["data_dir"],
            scene_pts_num=dataset_config["scene_pts_num"],
            voxel_size=dataset_config["voxel_size"],
        )
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=dataset_config["batch_size"],
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader

    def get_model(self, model_config=None):
        model = GraspNet(seed_feat_dim=model_config["general"]["seed_feat_dim"], is_training=False)
        model.to("cuda")
        checkpoint = torch.load(model_config["general"]["checkpoint_path"])
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(
            "-> loaded checkpoint %s (epoch: %d)" % (model_config["general"]["checkpoint_path"], start_epoch)
        )
        model.eval()
        return model

    def prediction(self, model, dataloader, require_gripper=False, top_k=10):
        preds = {}

        for batch_idx, batch_data in enumerate(dataloader):
            try:
                # Samples flagged by the dataset as failed loads carry an "error"
                # key; record them as empty predictions and move on. The [0]
                # indexing here (and in the except branch) assumes batch_size=1.
                if "error" in batch_data:
                    frame_path = batch_data["frame_path"][0]
                    object_name = batch_data["object_name"][0]
                    preds[frame_path] = {object_name: None}
                    print("No graspable points found at frame: ", frame_path)
                    continue
                print("Processing batch: ", batch_idx, "/", len(dataloader))
                # Move tensors (including nested list entries) onto the GPU.
                for key in batch_data:
                    if "list" in key:
                        for i in range(len(batch_data[key])):
                            for j in range(len(batch_data[key][i])):
                                batch_data[key][i][j] = batch_data[key][i][j].to("cuda")
                    elif not isinstance(batch_data[key], list):
                        batch_data[key] = batch_data[key].to("cuda")
                with torch.no_grad():
                    end_points = model(batch_data)
                    if end_points is None:
                        frame_path = batch_data["frame_path"][0]
                        object_name = batch_data["object_name"][0]
                        preds[frame_path] = {object_name: None}
                        print("No graspable points found at frame: ", frame_path)
                        continue
                    grasp_preds = self.decode_pred(end_points)

                    standard_grasp_preds = GSNetPreprocessor.standard_pred_decode(end_points)
                    standard_preds = standard_grasp_preds[0].detach().cpu().numpy()
                    if require_gripper:
                        # NMS and score sorting via graspnetAPI, then sample points
                        # on each gripper mesh for a lightweight representation.
                        gg = GraspGroup(standard_preds)
                        gg = gg.nms()
                        gg = gg.sort_by_score()
                        grippers = gg.to_open3d_geometry_list()
                        gp_pts_list = np.asarray([np.asarray(gripper_mesh.sample_points_uniformly(48).points) for gripper_mesh in grippers], dtype=np.float16)
                        gp_score_list = gg.scores

                for sample_idx in range(len(batch_data["frame_path"])):
                    frame_path = batch_data["frame_path"][sample_idx]
                    object_name = batch_data["object_name"][sample_idx]
                    if frame_path not in preds:
                        preds[frame_path] = {object_name: {}}

                    preds[frame_path][object_name] = grasp_preds[sample_idx]
                    preds[frame_path][object_name]["obj_pcl_dict"] = (
                        batch_data["obj_pcl_dict"][sample_idx]
                    )
                    if require_gripper:
                        preds[frame_path][object_name]["gripper"] = {
                            "gripper_pose": gp_pts_list.tolist(),
                            "gripper_score": gp_score_list.tolist()
                        }
            except Exception as e:
                print("Error in inference: ", e)
                # ----- Debug Trace ----- #
                print(batch_data["frame_path"])
                import ipdb; ipdb.set_trace()
                frame_path = batch_data["frame_path"][0]
                object_name = batch_data["object_name"][0]
                preds[frame_path] = {object_name: {}}
                # ------------------------ #

        results = {}
        for frame_path in preds:
            try:
                predict_results = {}
                for object_name in preds[frame_path]:
                    if object_name is None or preds[frame_path][object_name] is None:
                        continue
                    grasp_center = preds[frame_path][object_name]["grasp_center"]
                    grasp_score = preds[frame_path][object_name]["grasp_score"]
                    obj_pcl_dict = preds[frame_path][object_name]["obj_pcl_dict"]
                    if require_gripper:
                        gripper = preds[frame_path][object_name]["gripper"]
                    # Keep only grasp candidates whose center coincides with a
                    # point of the target object's own point cloud.
                    grasp_center = grasp_center.unsqueeze(1)
                    obj_pcl = obj_pcl_dict[object_name]
                    obj_pcl = torch.tensor(
                        obj_pcl.astype(np.float32), device=grasp_center.device
                    )
                    obj_pcl = obj_pcl.unsqueeze(0)
                    grasp_obj_table = (grasp_center == obj_pcl).all(dim=-1)
                    obj_pts_on_grasp = grasp_obj_table.any(dim=1)
                    obj_graspable_pts = grasp_center[obj_pts_on_grasp].squeeze(1)

                    obj_graspable_pts_score = grasp_score[obj_pts_on_grasp]
                    obj_graspable_pts_info = torch.cat(
                        [obj_graspable_pts, obj_graspable_pts_score], dim=1
                    )

                    if obj_graspable_pts.shape[0] == 0:
                        obj_graspable_pts_info = torch.zeros((top_k, 4))
                    ranked_obj_graspable_pts_info = self.sample_graspable_pts(
                        obj_graspable_pts_info, top_k=top_k
                    )
                    predict_results[object_name] = {
                        "positions": ranked_obj_graspable_pts_info[:, :3]
                        .cpu()
                        .numpy()
                        .tolist(),
                        "scores": ranked_obj_graspable_pts_info[:, 3]
                        .cpu()
                        .numpy()
                        .tolist(),
                    }
                if require_gripper:
                    results[frame_path] = {"predicted_results": predict_results, "gripper": gripper}
                else:
                    results[frame_path] = {"predicted_results": predict_results}

            except Exception as e:
                print("Error in postprocessing: ", e)
                # ----- Debug Trace ----- #
                print(frame_path)
                import ipdb; ipdb.set_trace()
                # ------------------------ #

        print("Prediction finished")
        return results

    @staticmethod
    def sample_graspable_pts(graspable_pts, top_k=50):
        # Pad by random re-sampling when fewer than top_k candidates exist,
        # then return the top_k rows ranked by score (column 3, descending).
        if graspable_pts.shape[0] < top_k:
            sampled_indices = torch.randint(0, graspable_pts.shape[0], (top_k,))
            graspable_pts = graspable_pts[sampled_indices]
        sorted_indices = torch.argsort(graspable_pts[:, 3], descending=True)
        top_graspable_pts = graspable_pts[sorted_indices][:top_k]
        return top_graspable_pts

    def decode_pred(self, end_points):
        batch_size = len(end_points["point_clouds"])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points["xyz_graspable"][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points["grasp_score_pred"][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            grasp_score, _ = torch.max(grasp_score, -1)  # [M_POINT]
            grasp_score = grasp_score.view(-1, 1)
            grasp_preds.append(
                {"grasp_center": grasp_center, "grasp_score": grasp_score}
            )
        return grasp_preds
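
    # decode_pred() keeps only the grasp centers and, per candidate point, the
    # best score over the NUM_ANGLE x NUM_DEPTH bins; standard_pred_decode()
    # below additionally reconstructs width, depth and the full rotation so the
    # result can be wrapped in a graspnetAPI GraspGroup.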

    @staticmethod
    def standard_pred_decode(end_points):
        batch_size = len(end_points['point_clouds'])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points['xyz_graspable'][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points['grasp_score_pred'][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            grasp_score, grasp_score_inds = torch.max(grasp_score, -1)  # [M_POINT]
            grasp_score = grasp_score.view(-1, 1)
            # Recover the angle/depth bin of the best-scoring prediction.
            grasp_angle = (grasp_score_inds // GSNetPreprocessor.NUM_DEPTH) * np.pi / 12
            grasp_depth = (grasp_score_inds % GSNetPreprocessor.NUM_DEPTH + 1) * 0.01
            grasp_depth = grasp_depth.view(-1, 1)
            grasp_width = 1.2 * end_points['grasp_width_pred'][i] / 10.
            grasp_width = grasp_width.view(GSNetPreprocessor.M_POINT, GSNetPreprocessor.NUM_ANGLE * GSNetPreprocessor.NUM_DEPTH)
            grasp_width = torch.gather(grasp_width, 1, grasp_score_inds.view(-1, 1))
            grasp_width = torch.clamp(grasp_width, min=0., max=GSNetPreprocessor.GRASP_MAX_WIDTH)

            approaching = -end_points['grasp_top_view_xyz'][i].float()
            grasp_rot = GSNetPreprocessor.batch_viewpoint_params_to_matrix(approaching, grasp_angle)
            grasp_rot = grasp_rot.view(GSNetPreprocessor.M_POINT, 9)

            # merge preds
            grasp_height = 0.02 * torch.ones_like(grasp_score)
            obj_ids = -1 * torch.ones_like(grasp_score)
            grasp_preds.append(
                torch.cat([grasp_score, grasp_width, grasp_height, grasp_depth, grasp_rot, grasp_center, obj_ids], dim=-1))
        return grasp_preds
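
    # The concatenated rows follow graspnetAPI's 17-column grasp array layout:
    # [score, width, height, depth, flattened 3x3 rotation (9), translation (3),
    # object_id], which is what GraspGroup(standard_preds) expects above.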

    @staticmethod
    def batch_viewpoint_params_to_matrix(batch_towards, batch_angle):
        # Build a rotation whose x-axis is the approach direction, with an
        # in-plane rotation of batch_angle about that axis.
        axis_x = batch_towards
        ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
        zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
        axis_y = torch.stack([-axis_x[:, 1], axis_x[:, 0], zeros], dim=-1)
        mask_y = (torch.norm(axis_y, dim=-1) == 0)
        axis_y[mask_y, 1] = 1
        axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True)
        axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True)
        axis_z = torch.cross(axis_x, axis_y)
        sin = torch.sin(batch_angle)
        cos = torch.cos(batch_angle)
        R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1)
        R1 = R1.reshape([-1, 3, 3])
        R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1)
        batch_matrix = torch.matmul(R2, R1)
        return batch_matrix
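
    # This appears to mirror the batch_viewpoint_params_to_matrix helper used in
    # GSNet / graspnetAPI: R2 maps the canonical frame onto [approach, y, z] and
    # R1 applies the in-plane grasp angle as a rotation about the x-axis.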


if __name__ == "__main__":
    gs_preproc = GSNetPreprocessor(config_path="configs/server_gsnet_preprocess_config.yaml")
    gs_preproc.run()