sim on omni_test
@@ -12,7 +12,7 @@ settings:
   experiment:
     name: test_inference
     root_dir: "experiments"
-    model_path: "/media/hofee/data/weights/nbv_grasping/full_149_241009.pth"
+    model_path: "/media/hofee/data/weights/nbv_grasping/grasping_full_centralized_ood_111.pth"
     use_cache: True
     small_batch_overfit: False

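For reference, the updated experiment block can be read with standard PyYAML. This is a hedged sketch only: the nesting under settings: is inferred from the hunk header, and the project's own config loader may differ.

import yaml

# Hedged sketch: nesting inferred from the "@@ ... settings:" hunk context;
# the project may load this config differently.
cfg_text = """
settings:
  experiment:
    name: test_inference
    root_dir: "experiments"
    model_path: "/media/hofee/data/weights/nbv_grasping/grasping_full_centralized_ood_111.pth"
    use_cache: True
    small_batch_overfit: False
"""
cfg = yaml.safe_load(cfg_text)
print(cfg["settings"]["experiment"]["model_path"])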
@@ -108,8 +108,10 @@ class ActivePerceptionSingleViewPolicy(SingleViewPolicy):
                                               support_id,
                                               scene_sample_num=16384,
                                               target_sample_num=1024)
-        ap_input = {'target_pts': self.target_points,
-                    'scene_pts': self.scene_points}
+        target_points_center = torch.mean(self.target_points, dim=1)
+        print("target_points_center", target_points_center)
+        ap_input = {'target_pts': self.target_points - target_points_center,
+                    'scene_pts': self.scene_points - target_points_center}
        self.scene_points_cache = self.scene_points.cpu().numpy()[0]
        self.publish_pointcloud(self.scene_points_cache)
        ap_output = self.ap_inference_engine.inference(ap_input)
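The added lines center both point clouds on the target centroid before inference. Below is a minimal standalone sketch of that step, assuming (B, N, 3) tensors as suggested by the surrounding code; keepdim=True is used here so the subtraction broadcasts for any batch size, whereas the diff's version relies on the batch dimension being 1.

import torch

def center_on_target(target_pts: torch.Tensor, scene_pts: torch.Tensor) -> dict:
    # target_pts: (B, N, 3), scene_pts: (B, M, 3) -- assumed shapes.
    # keepdim=True keeps the centroid as (B, 1, 3) so the subtraction
    # broadcasts for any batch size, not only B == 1.
    center = torch.mean(target_pts, dim=1, keepdim=True)
    return {'target_pts': target_pts - center,
            'scene_pts': scene_pts - center}

# Usage with dummy clouds shaped like the sampled ones above:
ap_input = center_on_target(torch.rand(1, 1024, 3), torch.rand(1, 16384, 3))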
@@ -125,26 +127,32 @@ class ActivePerceptionSingleViewPolicy(SingleViewPolicy):
        look_at_center_world = (src_cam_to_world_mat.cpu().numpy() @ look_at_center_cam_homogeneous)[:3]
        look_at_center_world = torch.from_numpy(look_at_center_world).float().to("cuda:0")
        # Get the NBV
-        dst_cam_to_world_mat = self.get_transformed_mat(src_cam_to_world_mat,
-                                                        est_delta_rot_mat,
-                                                        look_at_center_world)
-        dst_cam_to_world_mat_numpy = dst_cam_to_world_mat.cpu().numpy()
-        dst_transform = Transform.from_matrix(dst_cam_to_world_mat_numpy)
-        x_d = dst_transform
+        ratio = 1.0
+        while(True):
+            dst_cam_to_world_mat = self.get_transformed_mat(src_cam_to_world_mat,
+                                                            est_delta_rot_mat,
+                                                            look_at_center_world, ratio=ratio)
+            dst_cam_to_world_mat_numpy = dst_cam_to_world_mat.cpu().numpy()
+            dst_transform = Transform.from_matrix(dst_cam_to_world_mat_numpy)
+            x_d = dst_transform

-        # Check if this pose available
-        print("found a NBV pose")
-        if(self.solve_cam_ik(self.q0, x_d)):
-            self.vis_cam_pose(x_d)
-            self.x_d = x_d
-            self.updated = True
-            print("the NBV pose is reachable")
-            return
-        else:
-            self.unreachable_poses.append(x_d)
-            self.vis_unreachable_pose(self.unreachable_poses)
-            print("the NBV pose is not reachable")
+            # Check if this pose available
+            print("found a NBV pose")
+            if(self.solve_cam_ik(self.q0, x_d)):
+                self.vis_cam_pose(x_d)
+                self.x_d = x_d
+                self.updated = True
+                print("the NBV pose is reachable")
+                import ipdb; ipdb.set_trace()
+                return
+            else:
+                self.unreachable_poses.append(x_d)
+                self.vis_unreachable_pose(self.unreachable_poses)
+                print("the NBV pose is not reachable, decreasing ratio to ", ratio)
+                ratio -= 0.01
+                if(ratio < 0.1):
+                    print("the NBV pose is not reachable, giving up")
+                    return
        # Policy has produced an available nbv and moved to that camera pose
        if(self.updated == True):
            # Request grasping poses from GSNet
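The replacement wraps NBV generation in a retry loop: when the proposed camera pose has no IK solution, the view is regenerated with a smaller ratio (closer to the target) until one is reachable or ratio falls below 0.1. A standalone sketch of that control flow; propose_pose and is_reachable are hypothetical stand-ins for get_transformed_mat and solve_cam_ik, and the step/floor values mirror the diff.

import numpy as np

def find_reachable_nbv(propose_pose, is_reachable, ratio_step=0.01, min_ratio=0.1):
    # propose_pose(ratio) -> 4x4 camera-to-world matrix (stand-in for
    # get_transformed_mat); is_reachable(pose) -> bool (stand-in for solve_cam_ik).
    ratio = 1.0
    while ratio >= min_ratio:
        pose = propose_pose(ratio)
        if is_reachable(pose):
            return pose, ratio          # reachable view found
        ratio -= ratio_step             # pull the candidate view closer to the target
    return None, ratio                  # gave up, mirroring the ratio < 0.1 branch

# Trivial usage with dummy callables:
pose, used_ratio = find_reachable_nbv(lambda r: np.eye(4), lambda p: True)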
@@ -182,7 +190,7 @@ class ActivePerceptionSingleViewPolicy(SingleViewPolicy):

            gsnet_grasping_poses = np.asarray(self.request_grasping_pose(gsnet_input_points.tolist()))
            print(gsnet_grasping_poses[0].keys())
-            import ipdb; ipdb.set_trace()
+            #import ipdb; ipdb.set_trace()

            # DEBUG: publish grasps
            # self.publish_grasps(gsnet_grasping_poses)
@@ -318,14 +326,14 @@ class ActivePerceptionSingleViewPolicy(SingleViewPolicy):
        return torch.stack((b1, b2, b3), dim=-2)


-    def get_transformed_mat(self, src_mat, delta_rot, target_center_w):
+    def get_transformed_mat(self, src_mat, delta_rot, target_center_w, ratio=1.0):
        src_rot = src_mat[:3, :3]
        dst_rot = src_rot @ delta_rot.T
        dst_mat = torch.eye(4).to(dst_rot.device)
        dst_mat[:3, :3] = dst_rot
        distance = torch.norm(target_center_w - src_mat[:3, 3])
        z_axis_camera = dst_rot[:3, 2].reshape(-1)
-        new_camera_position_w = target_center_w - distance * z_axis_camera
+        new_camera_position_w = target_center_w - distance * z_axis_camera * ratio
        dst_mat[:3, 3] = new_camera_position_w
        return dst_mat

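In get_transformed_mat, ratio scales the camera-to-target distance along the rotated camera's viewing axis, so a pose with ratio < 1 keeps the same look-at direction but stands closer to the target. A small numpy check of that geometry with illustrative values only:

import numpy as np

# Illustrative values: identity "new" rotation, target 0.5 m in front of the source camera.
dst_rot = np.eye(3)
target_center_w = np.array([0.0, 0.0, 0.5])
src_position = np.zeros(3)

distance = np.linalg.norm(target_center_w - src_position)
z_axis_camera = dst_rot[:, 2]        # camera viewing axis (+z) expressed in world frame

for ratio in (1.0, 0.5):
    new_camera_position_w = target_center_w - distance * z_axis_camera * ratio
    # The camera sits ratio * distance away from the target, still looking at it.
    assert np.isclose(np.linalg.norm(target_center_w - new_camera_position_w),
                      ratio * distance)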
@@ -124,9 +124,7 @@ class GraspController:
                depth_img, seg_image, pose, q = self.get_state()
                current_p = pose.as_matrix()[:3,3]
                target_p = self.policy.x_d.as_matrix()[:3,3]
-                linear_d = np.sqrt((current_p[0]-target_p[0])**2+
-                                   (current_p[1]-target_p[1])**2+
-                                   (current_p[2]-target_p[2])**2)
+                linear_d = np.linalg.norm(current_p - target_p)
                if(linear_d < self.move_to_target_threshold):
                    # Arrived
                    moving_to_The_target = False
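The one-liner is equivalent to the removed expression: np.linalg.norm of the difference vector is the same Euclidean distance as the explicit square root of summed squares. A quick check with made-up positions:

import numpy as np

current_p = np.array([0.30, 0.10, 0.45])   # illustrative current camera position
target_p = np.array([0.25, 0.12, 0.50])    # illustrative target NBV position

explicit = np.sqrt((current_p[0] - target_p[0])**2 +
                   (current_p[1] - target_p[1])**2 +
                   (current_p[2] - target_p[2])**2)
assert np.isclose(np.linalg.norm(current_p - target_p), explicit)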
@@ -139,7 +137,9 @@ class GraspController:
                depth_img, seg_image, pose, q = self.get_state()
                target_seg_id = self.get_target_id(TargetIDRequest()).id
                support_seg_id = self.get_support_id(TargetIDRequest()).id
-                self.policy.update(depth_img, seg_image, target_seg_id, support_seg_id, pose, q)
+                # print(target_seg_id, support_seg_id)
+                # import ipdb; ipdb.set_trace()
+                self.policy.update(depth_img, seg_image, target_seg_id, pose, q)
                r.sleep()
        else:
            print("Unsupported policy type: "+str(self.policy.policy_type))