From 52cb65dee81876de89308130a0933f14bee2cfc2 Mon Sep 17 00:00:00 2001 From: 0nhc Date: Sun, 13 Oct 2024 01:13:42 -0500 Subject: [PATCH] add seg_id interface for policy update() function --- src/active_grasp/active_perception_policy.py | 6 +++--- src/active_grasp/baselines.py | 6 +++--- src/active_grasp/controller.py | 6 ++++-- src/active_grasp/nbv.py | 2 +- src/active_grasp/policy.py | 4 ++-- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/active_grasp/active_perception_policy.py b/src/active_grasp/active_perception_policy.py index 7afe884..3bae045 100644 --- a/src/active_grasp/active_perception_policy.py +++ b/src/active_grasp/active_perception_policy.py @@ -18,8 +18,8 @@ class ActivePerceptionPolicy(MultiViewPolicy): def activate(self, bbox, view_sphere): super().activate(bbox, view_sphere) - def update(self, img, seg, x, q): - self.depth_image_to_ap_input(img, seg) + def update(self, img, seg, target_id, x, q): + self.depth_image_to_ap_input(img, seg, target_id) # if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable(): # self.done = True # else: @@ -41,7 +41,7 @@ class ActivePerceptionPolicy(MultiViewPolicy): # self.x_d = nbv - def depth_image_to_ap_input(self, depth_img, seg_img): + def depth_image_to_ap_input(self, depth_img, seg_img, target_id): K = self.intrinsic.K depth_shape = depth_img.shape seg_shape = seg_img.shape diff --git a/src/active_grasp/baselines.py b/src/active_grasp/baselines.py index 4f05ff7..5d76b88 100644 --- a/src/active_grasp/baselines.py +++ b/src/active_grasp/baselines.py @@ -4,7 +4,7 @@ from .policy import SingleViewPolicy, MultiViewPolicy, compute_error class InitialView(SingleViewPolicy): - def update(self, img, x, q): + def update(self, img, seg, target_id, x, q): self.x_d = x super().update(img, x, q) @@ -22,7 +22,7 @@ class TopTrajectory(MultiViewPolicy): self.x_d = self.view_sphere.get_view(0.0, 0.0) self.done = False if self.solve_cam_ik(self.q0, self.x_d) else True - def update(self, img, x, q): + def update(self, img, seg, target_id, x, q): self.integrate(img, x, q) linear, _ = compute_error(self.x_d, x) if np.linalg.norm(linear) < 0.02: @@ -33,5 +33,5 @@ class FixedTrajectory(MultiViewPolicy): def activate(self, bbox, view_sphere): pass - def update(self, img, x, q): + def update(self, img, seg, target_id, x, q): pass diff --git a/src/active_grasp/controller.py b/src/active_grasp/controller.py index 852ae09..6780e1a 100644 --- a/src/active_grasp/controller.py +++ b/src/active_grasp/controller.py @@ -9,7 +9,7 @@ import trimesh from .bbox import from_bbox_msg from .timer import Timer -from active_grasp.srv import Reset, ResetRequest +from active_grasp.srv import * from robot_helpers.ros import tf from robot_helpers.ros.conversions import * from robot_helpers.ros.panda import PandaArmClient, PandaGripperClient @@ -43,6 +43,7 @@ class GraspController: self.switch_controller = rospy.ServiceProxy( "controller_manager/switch_controller", SwitchController ) + self.get_target_id = rospy.ServiceProxy("get_target_seg_id", TargetID) def init_robot_connection(self): self.arm = PandaArmClient() @@ -108,7 +109,8 @@ class GraspController: r = rospy.Rate(self.policy_rate) while not self.policy.done: depth_img, seg_image, pose, q = self.get_state() - self.policy.update(depth_img, seg_image, pose, q) + target_seg_id = self.get_target_id(TargetIDRequest()) + self.policy.update(depth_img, seg_image, target_seg_id, pose, q) r.sleep() rospy.sleep(0.2) # Wait for a zero command to be sent to the robot. self.policy.deactivate() diff --git a/src/active_grasp/nbv.py b/src/active_grasp/nbv.py index 902b0c5..06501e5 100644 --- a/src/active_grasp/nbv.py +++ b/src/active_grasp/nbv.py @@ -84,7 +84,7 @@ class NextBestView(MultiViewPolicy): def activate(self, bbox, view_sphere): super().activate(bbox, view_sphere) - def update(self, img, seg, x, q): + def update(self, img, seg, target_id, x, q): if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable(): self.done = True else: diff --git a/src/active_grasp/policy.py b/src/active_grasp/policy.py index f0fca95..aba967c 100644 --- a/src/active_grasp/policy.py +++ b/src/active_grasp/policy.py @@ -75,7 +75,7 @@ class Policy: rospy.sleep(1.0) # Wait for tf tree to be updated self.vis.roi(self.task_frame, 0.3) - def update(self, img, seg, x, q): + def update(self, img, seg, target_id, x, q): raise NotImplementedError def filter_grasps(self, out, q): @@ -106,7 +106,7 @@ def select_best_grasp(grasps, qualities): class SingleViewPolicy(Policy): - def update(self, img, seg, x, q): + def update(self, img, seg, target_id, x, q): linear, _ = compute_error(self.x_d, x) if np.linalg.norm(linear) < 0.02: self.views.append(x)