add seg_id interface for policy update() function

This commit is contained in:
0nhc 2024-10-13 01:13:42 -05:00
parent 60e327b8fb
commit 52cb65dee8
5 changed files with 13 additions and 11 deletions

View File

@ -18,8 +18,8 @@ class ActivePerceptionPolicy(MultiViewPolicy):
def activate(self, bbox, view_sphere): def activate(self, bbox, view_sphere):
super().activate(bbox, view_sphere) super().activate(bbox, view_sphere)
def update(self, img, seg, x, q): def update(self, img, seg, target_id, x, q):
self.depth_image_to_ap_input(img, seg) self.depth_image_to_ap_input(img, seg, target_id)
# if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable(): # if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
# self.done = True # self.done = True
# else: # else:
@ -41,7 +41,7 @@ class ActivePerceptionPolicy(MultiViewPolicy):
# self.x_d = nbv # self.x_d = nbv
def depth_image_to_ap_input(self, depth_img, seg_img): def depth_image_to_ap_input(self, depth_img, seg_img, target_id):
K = self.intrinsic.K K = self.intrinsic.K
depth_shape = depth_img.shape depth_shape = depth_img.shape
seg_shape = seg_img.shape seg_shape = seg_img.shape

View File

@ -4,7 +4,7 @@ from .policy import SingleViewPolicy, MultiViewPolicy, compute_error
class InitialView(SingleViewPolicy): class InitialView(SingleViewPolicy):
def update(self, img, x, q): def update(self, img, seg, target_id, x, q):
self.x_d = x self.x_d = x
super().update(img, x, q) super().update(img, x, q)
@ -22,7 +22,7 @@ class TopTrajectory(MultiViewPolicy):
self.x_d = self.view_sphere.get_view(0.0, 0.0) self.x_d = self.view_sphere.get_view(0.0, 0.0)
self.done = False if self.solve_cam_ik(self.q0, self.x_d) else True self.done = False if self.solve_cam_ik(self.q0, self.x_d) else True
def update(self, img, x, q): def update(self, img, seg, target_id, x, q):
self.integrate(img, x, q) self.integrate(img, x, q)
linear, _ = compute_error(self.x_d, x) linear, _ = compute_error(self.x_d, x)
if np.linalg.norm(linear) < 0.02: if np.linalg.norm(linear) < 0.02:
@ -33,5 +33,5 @@ class FixedTrajectory(MultiViewPolicy):
def activate(self, bbox, view_sphere): def activate(self, bbox, view_sphere):
pass pass
def update(self, img, x, q): def update(self, img, seg, target_id, x, q):
pass pass

View File

@ -9,7 +9,7 @@ import trimesh
from .bbox import from_bbox_msg from .bbox import from_bbox_msg
from .timer import Timer from .timer import Timer
from active_grasp.srv import Reset, ResetRequest from active_grasp.srv import *
from robot_helpers.ros import tf from robot_helpers.ros import tf
from robot_helpers.ros.conversions import * from robot_helpers.ros.conversions import *
from robot_helpers.ros.panda import PandaArmClient, PandaGripperClient from robot_helpers.ros.panda import PandaArmClient, PandaGripperClient
@ -43,6 +43,7 @@ class GraspController:
self.switch_controller = rospy.ServiceProxy( self.switch_controller = rospy.ServiceProxy(
"controller_manager/switch_controller", SwitchController "controller_manager/switch_controller", SwitchController
) )
self.get_target_id = rospy.ServiceProxy("get_target_seg_id", TargetID)
def init_robot_connection(self): def init_robot_connection(self):
self.arm = PandaArmClient() self.arm = PandaArmClient()
@ -108,7 +109,8 @@ class GraspController:
r = rospy.Rate(self.policy_rate) r = rospy.Rate(self.policy_rate)
while not self.policy.done: while not self.policy.done:
depth_img, seg_image, pose, q = self.get_state() depth_img, seg_image, pose, q = self.get_state()
self.policy.update(depth_img, seg_image, pose, q) target_seg_id = self.get_target_id(TargetIDRequest())
self.policy.update(depth_img, seg_image, target_seg_id, pose, q)
r.sleep() r.sleep()
rospy.sleep(0.2) # Wait for a zero command to be sent to the robot. rospy.sleep(0.2) # Wait for a zero command to be sent to the robot.
self.policy.deactivate() self.policy.deactivate()

View File

@ -84,7 +84,7 @@ class NextBestView(MultiViewPolicy):
def activate(self, bbox, view_sphere): def activate(self, bbox, view_sphere):
super().activate(bbox, view_sphere) super().activate(bbox, view_sphere)
def update(self, img, seg, x, q): def update(self, img, seg, target_id, x, q):
if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable(): if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
self.done = True self.done = True
else: else:

View File

@ -75,7 +75,7 @@ class Policy:
rospy.sleep(1.0) # Wait for tf tree to be updated rospy.sleep(1.0) # Wait for tf tree to be updated
self.vis.roi(self.task_frame, 0.3) self.vis.roi(self.task_frame, 0.3)
def update(self, img, seg, x, q): def update(self, img, seg, target_id, x, q):
raise NotImplementedError raise NotImplementedError
def filter_grasps(self, out, q): def filter_grasps(self, out, q):
@ -106,7 +106,7 @@ def select_best_grasp(grasps, qualities):
class SingleViewPolicy(Policy): class SingleViewPolicy(Policy):
def update(self, img, seg, x, q): def update(self, img, seg, target_id, x, q):
linear, _ = compute_error(self.x_d, x) linear, _ = compute_error(self.x_d, x)
if np.linalg.norm(linear) < 0.02: if np.linalg.norm(linear) < 0.02:
self.views.append(x) self.views.append(x)