add seg_id interface for policy update() function
This commit is contained in:
parent
60e327b8fb
commit
52cb65dee8
@ -18,8 +18,8 @@ class ActivePerceptionPolicy(MultiViewPolicy):
|
|||||||
def activate(self, bbox, view_sphere):
|
def activate(self, bbox, view_sphere):
|
||||||
super().activate(bbox, view_sphere)
|
super().activate(bbox, view_sphere)
|
||||||
|
|
||||||
def update(self, img, seg, x, q):
|
def update(self, img, seg, target_id, x, q):
|
||||||
self.depth_image_to_ap_input(img, seg)
|
self.depth_image_to_ap_input(img, seg, target_id)
|
||||||
# if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
|
# if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
|
||||||
# self.done = True
|
# self.done = True
|
||||||
# else:
|
# else:
|
||||||
@ -41,7 +41,7 @@ class ActivePerceptionPolicy(MultiViewPolicy):
|
|||||||
|
|
||||||
# self.x_d = nbv
|
# self.x_d = nbv
|
||||||
|
|
||||||
def depth_image_to_ap_input(self, depth_img, seg_img):
|
def depth_image_to_ap_input(self, depth_img, seg_img, target_id):
|
||||||
K = self.intrinsic.K
|
K = self.intrinsic.K
|
||||||
depth_shape = depth_img.shape
|
depth_shape = depth_img.shape
|
||||||
seg_shape = seg_img.shape
|
seg_shape = seg_img.shape
|
||||||
|
@ -4,7 +4,7 @@ from .policy import SingleViewPolicy, MultiViewPolicy, compute_error
|
|||||||
|
|
||||||
|
|
||||||
class InitialView(SingleViewPolicy):
|
class InitialView(SingleViewPolicy):
|
||||||
def update(self, img, x, q):
|
def update(self, img, seg, target_id, x, q):
|
||||||
self.x_d = x
|
self.x_d = x
|
||||||
super().update(img, x, q)
|
super().update(img, x, q)
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ class TopTrajectory(MultiViewPolicy):
|
|||||||
self.x_d = self.view_sphere.get_view(0.0, 0.0)
|
self.x_d = self.view_sphere.get_view(0.0, 0.0)
|
||||||
self.done = False if self.solve_cam_ik(self.q0, self.x_d) else True
|
self.done = False if self.solve_cam_ik(self.q0, self.x_d) else True
|
||||||
|
|
||||||
def update(self, img, x, q):
|
def update(self, img, seg, target_id, x, q):
|
||||||
self.integrate(img, x, q)
|
self.integrate(img, x, q)
|
||||||
linear, _ = compute_error(self.x_d, x)
|
linear, _ = compute_error(self.x_d, x)
|
||||||
if np.linalg.norm(linear) < 0.02:
|
if np.linalg.norm(linear) < 0.02:
|
||||||
@ -33,5 +33,5 @@ class FixedTrajectory(MultiViewPolicy):
|
|||||||
def activate(self, bbox, view_sphere):
|
def activate(self, bbox, view_sphere):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def update(self, img, x, q):
|
def update(self, img, seg, target_id, x, q):
|
||||||
pass
|
pass
|
||||||
|
@ -9,7 +9,7 @@ import trimesh
|
|||||||
|
|
||||||
from .bbox import from_bbox_msg
|
from .bbox import from_bbox_msg
|
||||||
from .timer import Timer
|
from .timer import Timer
|
||||||
from active_grasp.srv import Reset, ResetRequest
|
from active_grasp.srv import *
|
||||||
from robot_helpers.ros import tf
|
from robot_helpers.ros import tf
|
||||||
from robot_helpers.ros.conversions import *
|
from robot_helpers.ros.conversions import *
|
||||||
from robot_helpers.ros.panda import PandaArmClient, PandaGripperClient
|
from robot_helpers.ros.panda import PandaArmClient, PandaGripperClient
|
||||||
@ -43,6 +43,7 @@ class GraspController:
|
|||||||
self.switch_controller = rospy.ServiceProxy(
|
self.switch_controller = rospy.ServiceProxy(
|
||||||
"controller_manager/switch_controller", SwitchController
|
"controller_manager/switch_controller", SwitchController
|
||||||
)
|
)
|
||||||
|
self.get_target_id = rospy.ServiceProxy("get_target_seg_id", TargetID)
|
||||||
|
|
||||||
def init_robot_connection(self):
|
def init_robot_connection(self):
|
||||||
self.arm = PandaArmClient()
|
self.arm = PandaArmClient()
|
||||||
@ -108,7 +109,8 @@ class GraspController:
|
|||||||
r = rospy.Rate(self.policy_rate)
|
r = rospy.Rate(self.policy_rate)
|
||||||
while not self.policy.done:
|
while not self.policy.done:
|
||||||
depth_img, seg_image, pose, q = self.get_state()
|
depth_img, seg_image, pose, q = self.get_state()
|
||||||
self.policy.update(depth_img, seg_image, pose, q)
|
target_seg_id = self.get_target_id(TargetIDRequest())
|
||||||
|
self.policy.update(depth_img, seg_image, target_seg_id, pose, q)
|
||||||
r.sleep()
|
r.sleep()
|
||||||
rospy.sleep(0.2) # Wait for a zero command to be sent to the robot.
|
rospy.sleep(0.2) # Wait for a zero command to be sent to the robot.
|
||||||
self.policy.deactivate()
|
self.policy.deactivate()
|
||||||
|
@ -84,7 +84,7 @@ class NextBestView(MultiViewPolicy):
|
|||||||
def activate(self, bbox, view_sphere):
|
def activate(self, bbox, view_sphere):
|
||||||
super().activate(bbox, view_sphere)
|
super().activate(bbox, view_sphere)
|
||||||
|
|
||||||
def update(self, img, seg, x, q):
|
def update(self, img, seg, target_id, x, q):
|
||||||
if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
|
if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
|
||||||
self.done = True
|
self.done = True
|
||||||
else:
|
else:
|
||||||
|
@ -75,7 +75,7 @@ class Policy:
|
|||||||
rospy.sleep(1.0) # Wait for tf tree to be updated
|
rospy.sleep(1.0) # Wait for tf tree to be updated
|
||||||
self.vis.roi(self.task_frame, 0.3)
|
self.vis.roi(self.task_frame, 0.3)
|
||||||
|
|
||||||
def update(self, img, seg, x, q):
|
def update(self, img, seg, target_id, x, q):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def filter_grasps(self, out, q):
|
def filter_grasps(self, out, q):
|
||||||
@ -106,7 +106,7 @@ def select_best_grasp(grasps, qualities):
|
|||||||
|
|
||||||
|
|
||||||
class SingleViewPolicy(Policy):
|
class SingleViewPolicy(Policy):
|
||||||
def update(self, img, seg, x, q):
|
def update(self, img, seg, target_id, x, q):
|
||||||
linear, _ = compute_error(self.x_d, x)
|
linear, _ = compute_error(self.x_d, x)
|
||||||
if np.linalg.norm(linear) < 0.02:
|
if np.linalg.norm(linear) < 0.02:
|
||||||
self.views.append(x)
|
self.views.append(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user