support different controller for different type of policies

This commit is contained in:
0nhc 2024-10-13 06:51:46 -05:00
parent 79d709d1ac
commit 7cad070b13
3 changed files with 32 additions and 18 deletions

View File

@ -60,9 +60,6 @@ class ActivePerceptionMultiViewPolicy(MultiViewPolicy):
self.pcdvis = RealTime3DVisualizer() self.pcdvis = RealTime3DVisualizer()
def activate(self, bbox, view_sphere):
super().activate(bbox, view_sphere)
def update(self, img, seg, target_id, x, q): def update(self, img, seg, target_id, x, q):
if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable(): if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
self.done = True self.done = True

View File

@ -108,6 +108,8 @@ class GraspController:
self.policy.activate(bbox, self.view_sphere) self.policy.activate(bbox, self.view_sphere)
timer = rospy.Timer(rospy.Duration(1.0 / self.control_rate), self.send_vel_cmd) timer = rospy.Timer(rospy.Duration(1.0 / self.control_rate), self.send_vel_cmd)
r = rospy.Rate(self.policy_rate) r = rospy.Rate(self.policy_rate)
if(self.policy.policy_type=="single_view"):
while not self.policy.done: while not self.policy.done:
depth_img, seg_image, pose, q = self.get_state() depth_img, seg_image, pose, q = self.get_state()
target_seg_id = self.get_target_id(TargetIDRequest()).id target_seg_id = self.get_target_id(TargetIDRequest()).id
@ -125,6 +127,14 @@ class GraspController:
# Arrived # Arrived
moving_to_The_target = False moving_to_The_target = False
r.sleep() r.sleep()
elif(self.policy.policy_type=="multi_view"):
while not self.policy.done:
depth_img, seg_image, pose, q = self.get_state()
target_seg_id = self.get_target_id(TargetIDRequest()).id
self.policy.update(depth_img, seg_image, target_seg_id, pose, q)
r.sleep()
else:
print("Unsupported policy type: "+str(self.policy.policy_type))
# Wait for a zero command to be sent to the robot. # Wait for a zero command to be sent to the robot.
rospy.sleep(0.2) rospy.sleep(0.2)

View File

@ -67,6 +67,8 @@ class Policy:
self.done = False self.done = False
self.info = {} self.info = {}
self.policy_type = "policy"
def calibrate_task_frame(self): def calibrate_task_frame(self):
xyz = np.r_[self.bbox.center[:2] - 0.15, self.bbox.min[2] - 0.05] xyz = np.r_[self.bbox.center[:2] - 0.15, self.bbox.min[2] - 0.05]
self.T_base_task = Transform.from_translation(xyz) self.T_base_task = Transform.from_translation(xyz)
@ -106,6 +108,10 @@ def select_best_grasp(grasps, qualities):
class SingleViewPolicy(Policy): class SingleViewPolicy(Policy):
def activate(self, bbox, view_sphere):
super().activate(bbox, view_sphere)
self.policy_type = "single_view"
def update(self, img, seg, target_id, x, q): def update(self, img, seg, target_id, x, q):
linear, _ = compute_error(self.x_d, x) linear, _ = compute_error(self.x_d, x)
if np.linalg.norm(linear) < 0.02: if np.linalg.norm(linear) < 0.02:
@ -143,6 +149,7 @@ class MultiViewPolicy(Policy):
def activate(self, bbox, view_sphere): def activate(self, bbox, view_sphere):
super().activate(bbox, view_sphere) super().activate(bbox, view_sphere)
self.qual_hist = np.zeros((self.T,) + (40,) * 3, np.float32) self.qual_hist = np.zeros((self.T,) + (40,) * 3, np.float32)
self.policy_type = "multi_view"
def integrate(self, img, x, q): def integrate(self, img, x, q):
self.views.append(x) self.views.append(x)