nbv_sim/src/active_grasp/controller.py

231 lines
9.0 KiB
Python
Raw Normal View History

2021-09-04 15:50:29 +02:00
from controller_manager_msgs.srv import *
2021-08-03 18:11:30 +02:00
import copy
import cv_bridge
2021-09-03 22:39:17 +02:00
from geometry_msgs.msg import Twist
2021-07-06 14:00:04 +02:00
import numpy as np
import rospy
2021-08-06 15:23:50 +02:00
from sensor_msgs.msg import Image
2021-10-12 17:16:41 +02:00
import trimesh
2021-07-06 14:00:04 +02:00
2021-08-05 13:45:22 +02:00
from .bbox import from_bbox_msg
from .timer import Timer
2021-07-07 16:29:50 +02:00
from active_grasp.srv import Reset, ResetRequest
2021-08-03 18:11:30 +02:00
from robot_helpers.ros import tf
from robot_helpers.ros.conversions import *
2021-09-12 14:40:17 +02:00
from robot_helpers.ros.panda import PandaArmClient, PandaGripperClient
2021-10-12 17:16:41 +02:00
from robot_helpers.ros.moveit import MoveItClient, create_collision_object_from_mesh
2021-07-22 11:05:30 +02:00
from robot_helpers.spatial import Rotation, Transform
2021-09-11 20:49:55 +02:00
from vgn.utils import look_at, cartesian_to_spherical, spherical_to_cartesian
2021-07-06 14:00:04 +02:00
class GraspController:
    """Run one grasp trial: search for a grasp with a view policy, then execute it.

    While searching, the camera is driven with Cartesian velocity commands that
    are published from a timer callback; the policy is updated with depth images
    in the main loop. Once the policy reports a grasp, control is handed to the
    joint trajectory controller and MoveIt plans and executes the grasp.
    """

    def __init__(self, policy):
        """Store ``policy`` and set up parameters, services, robot, MoveIt and camera."""
        self.policy = policy
        self.load_parameters()
        self.init_service_proxies()
        self.init_robot_connection()
        self.init_moveit()
        self.init_camera_stream()

    def load_parameters(self):
        """Read the controller configuration from the ROS parameter server."""
        self.base_frame = rospy.get_param("~base_frame_id")
        # The parameter holds the offset in the opposite direction, hence the inverse.
        self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()
        self.cam_frame = rospy.get_param("~camera/frame_id")
        self.depth_topic = rospy.get_param("~camera/depth_topic")
        self.min_z_dist = rospy.get_param("~camera/min_z_dist")
        self.control_rate = rospy.get_param("~control_rate")
        self.linear_vel = rospy.get_param("~linear_vel")
        # NOTE(review): unlike the others this name is not private ("~"); confirm intended.
        self.policy_rate = rospy.get_param("policy/rate")

    def init_service_proxies(self):
        """Create proxies for the environment reset and controller switching services."""
        self.reset_env = rospy.ServiceProxy("reset", Reset)
        self.switch_controller = rospy.ServiceProxy(
            "controller_manager/switch_controller", SwitchController
        )

    def init_robot_connection(self):
        """Connect to arm and gripper and create the Cartesian velocity publisher."""
        self.arm = PandaArmClient()
        self.gripper = PandaGripperClient()
        topic = rospy.get_param("cartesian_velocity_controller/topic")
        self.cartesian_vel_pub = rospy.Publisher(topic, Twist, queue_size=10)

    def init_moveit(self):
        """Create the MoveIt client and configure its planner."""
        self.moveit = MoveItClient("panda_arm")
        rospy.sleep(1.0)  # Wait for connections to be established.
        self.moveit.move_group.set_planner_id("RRTstarkConfigDefault")
        self.moveit.move_group.set_planning_time(3.0)

    def _switch_controllers(self, start, stop):
        """Ask the controller manager to start ``start`` and stop ``stop``."""
        req = SwitchControllerRequest()
        req.start_controllers = [start]
        req.stop_controllers = [stop]
        # presumably BEST_EFFORT in controller_manager_msgs -- TODO confirm
        req.strictness = 1
        self.switch_controller(req)

    def switch_to_cartesian_velocity_control(self):
        """Activate the Cartesian velocity controller in place of the trajectory one."""
        self._switch_controllers(
            "cartesian_velocity_controller", "position_joint_trajectory_controller"
        )

    def switch_to_joint_trajectory_control(self):
        """Activate the joint trajectory controller in place of the velocity one."""
        self._switch_controllers(
            "position_joint_trajectory_controller", "cartesian_velocity_controller"
        )

    def init_camera_stream(self):
        """Subscribe to the depth topic; only the most recent image is kept."""
        self.cv_bridge = cv_bridge.CvBridge()
        # Set before subscribing so that get_state can tell "no image yet" apart
        # from a missing attribute.
        self.latest_depth_msg = None
        rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)

    def sensor_cb(self, msg):
        """Remember ``msg`` as the latest depth image."""
        self.latest_depth_msg = msg

    def run(self):
        """Execute one complete trial and return its statistics.

        Returns:
            The dictionary built by ``collect_info``. Its ``"result"`` entry is
            ``"aborted"`` when the search produced no grasp, otherwise whatever
            ``execute_grasp`` returned.
        """
        bbox = self.reset()
        self.switch_to_cartesian_velocity_control()
        with Timer("search_time"):
            grasp = self.search_grasp(bbox)
        if grasp:
            self.switch_to_joint_trajectory_control()
            with Timer("grasp_time"):
                res = self.execute_grasp(grasp)
        else:
            res = "aborted"
        return self.collect_info(res)

    def reset(self):
        """Reset timers, planning scene and environment; return the target bbox."""
        Timer.reset()
        self.moveit.scene.clear()
        res = self.reset_env(ResetRequest())
        rospy.sleep(1.0)  # Wait for the TF tree to be updated.
        return from_bbox_msg(res.bbox)

    def search_grasp(self, bbox):
        """Update the policy with camera observations until it reports being done.

        Velocity commands are published concurrently by a timer that calls
        ``send_vel_cmd`` at ``control_rate``.

        Returns:
            ``self.policy.best_grasp`` after the policy has been deactivated.
        """
        self.view_sphere = ViewHalfSphere(bbox, self.min_z_dist)
        self.policy.activate(bbox, self.view_sphere)
        timer = rospy.Timer(rospy.Duration(1.0 / self.control_rate), self.send_vel_cmd)
        try:
            rate = rospy.Rate(self.policy_rate)
            while not self.policy.done:
                img, pose, q = self.get_state()
                self.policy.update(img, pose, q)
                rate.sleep()
            rospy.sleep(0.2)  # Wait for a zero command to be sent to the robot.
            self.policy.deactivate()
        finally:
            # Stop publishing velocity commands even if the loop above raised.
            timer.shutdown()
        return self.policy.best_grasp

    def get_state(self):
        """Return the latest depth image, the matching camera pose and the arm state.

        Blocks until a depth image is available if none has been received yet.

        Returns:
            Tuple ``(img, pose, q)``: the depth image as ``float32`` scaled by
            0.001 (presumably millimetres to metres -- TODO confirm), the
            transform looked up between base and camera frame at the image
            stamp, and the first element of ``self.arm.get_state()``.
        """
        q, _ = self.arm.get_state()
        msg = self.latest_depth_msg
        if msg is None:
            msg = rospy.wait_for_message(self.depth_topic, Image)
        msg = copy.deepcopy(msg)
        img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32) * 0.001
        pose = tf.lookup(self.base_frame, self.cam_frame, msg.header.stamp)
        return img, pose, q

    def send_vel_cmd(self, event):
        """Timer callback: publish a velocity command towards the policy's target.

        A zero command is sent while the policy has no target or is done.
        """
        if self.policy.x_d is None or self.policy.done:
            cmd = np.zeros(6)
        else:
            x = tf.lookup(self.base_frame, self.cam_frame)
            cmd = self.compute_velocity_cmd(self.policy.x_d, x)
        self.cartesian_vel_pub.publish(to_twist_msg(cmd))

    def compute_velocity_cmd(self, x_d, x):
        """Compute a 6D velocity command moving the camera from ``x`` towards ``x_d``.

        Args:
            x_d: Desired camera pose.
            x: Current camera pose.

        Returns:
            Array of three linear followed by three angular components.
        """
        r, theta, phi = cartesian_to_spherical(x.translation - self.view_sphere.center)
        e_t = x_d.translation - x.translation
        # Radial error; it only contributes while the camera is inside the sphere.
        e_n = (x.translation - self.view_sphere.center) * (self.view_sphere.r - r) / r
        linear = 1.0 * e_t + 6.0 * (r < self.view_sphere.r) * e_n
        # Limit the linear speed to linear_vel; the epsilon avoids dividing by zero.
        scale = np.linalg.norm(linear) + 1e-6
        linear *= np.clip(scale, 0.0, self.linear_vel) / scale
        # Rotate towards the sphere's view at the current angular position.
        angular = self.view_sphere.get_view(theta, phi).rotation * x.rotation.inv()
        angular = 1.0 * angular.as_rotvec()
        return np.r_[linear, angular]

    def execute_grasp(self, grasp):
        """Plan and execute ``grasp``.

        Returns:
            ``"no_motion_plan_found"`` if planning to the approach pose failed,
            otherwise ``"succeeded"`` or ``"failed"`` depending on the gripper
            reading after the retreat.
        """
        self.create_collision_scene()
        T_base_grasp = self.postprocess(grasp.pose)
        self.gripper.move(0.08)
        # Approach pose: offset of -0.06 applied on the grasp-frame side.
        T_base_approach = T_base_grasp * Transform.t_[0, 0, -0.06] * self.T_grasp_ee
        success, plan = self.moveit.plan(T_base_approach, 0.2, 0.2)
        if not success:
            return "no_motion_plan_found"
        # The collision objects are only needed for planning the approach.
        self.moveit.scene.clear()
        self.moveit.execute(plan)
        rospy.sleep(0.5)  # Wait for the planning scene to be updated
        self.moveit.gotoL(T_base_grasp * self.T_grasp_ee)
        rospy.sleep(0.5)
        self.gripper.grasp()
        # Retreat pose: offset of 0.05 applied on the base-frame side.
        T_base_retreat = Transform.t_[0, 0, 0.05] * T_base_grasp * self.T_grasp_ee
        self.moveit.gotoL(T_base_retreat)
        rospy.sleep(1.0)  # Wait to see whether the object slides out of the hand
        success = self.gripper.read() > 0.002
        return "succeeded" if success else "failed"

    def create_collision_scene(self):
        """Populate the MoveIt planning scene from the policy's reconstruction.

        The dominant plane of the scene cloud is added as ``"support"``; the
        remaining points are clustered and added as convex objects.
        """
        # Segment support surface
        cloud = self.policy.tsdf.get_scene_cloud()
        cloud = cloud.transform(self.policy.T_base_task.as_matrix())
        _, inliers = cloud.segment_plane(0.01, 3, 1000)
        support_cloud = cloud.select_by_index(inliers)
        cloud = cloud.select_by_index(inliers, invert=True)
        # Add collision object for the support
        self.add_collision_mesh("support", compute_convex_hull(support_cloud))
        # Generate convex collision objects for the remaining clusters
        self._add_object_hulls(cloud)

    def _add_object_hulls(self, cloud):
        """Cluster ``cloud`` and add one convex collision object per cluster.

        The hulls that were added are stored in ``self.hulls``; clusters whose
        hull cannot be computed or added are skipped with a warning.
        """
        labels = np.array(cloud.cluster_dbscan(eps=0.01, min_points=8))
        self.hulls = []
        if labels.size == 0:
            # No points left besides the support; max() of an empty array raises.
            return
        # Starting at 0 skips negative labels (presumably DBSCAN noise).
        for label in range(labels.max() + 1):
            segment = cloud.select_by_index(np.flatnonzero(labels == label))
            name = f"object_{label}"
            try:
                hull = compute_convex_hull(segment)
                self.add_collision_mesh(name, hull)
            except Exception as err:
                # Qhull fails in some edge cases
                rospy.logwarn("Skipping collision object %s: %s", name, err)
            else:
                self.hulls.append(hull)

    def add_collision_mesh(self, name, mesh):
        """Add ``mesh`` to the planning scene as object ``name`` in the base frame."""
        frame, pose = self.base_frame, Transform.identity()
        co = create_collision_object_from_mesh(name, frame, pose, mesh)
        self.moveit.scene.add_object(co)

    def postprocess(self, T_base_grasp):
        """Normalise the grasp orientation and apply a 0.01 offset along its z axis.

        Note that the rotation of the passed transform is modified in place when
        the grasp is flipped.
        """
        rot = T_base_grasp.rotation
        if rot.as_matrix()[:, 0][0] < 0:  # Ensure that the camera is pointing forward
            T_base_grasp.rotation = rot * Rotation.from_euler("z", np.pi)
        T_base_grasp *= Transform.t_[0.0, 0.0, 0.01]
        return T_base_grasp

    def collect_info(self, result):
        """Assemble the trial statistics.

        Returns:
            Dictionary with ``"result"``, ``"view_count"`` and ``"distance"``
            (summed distance between consecutive policy views), extended by
            ``self.policy.info`` and the recorded timers.
        """
        points = [p.translation for p in self.policy.views]
        d = np.sum([np.linalg.norm(p2 - p1) for p1, p2 in zip(points, points[1:])])
        info = {
            "result": result,
            "view_count": len(points),
            "distance": d,
        }
        info.update(self.policy.info)
        info.update(Timer.timers)
        return info
2021-09-11 20:49:55 +02:00
2021-10-12 17:57:09 +02:00
def compute_convex_hull(cloud):
    """Return the convex hull of ``cloud`` as a ``trimesh`` mesh.

    ``cloud`` must offer ``compute_convex_hull()`` returning a pair whose first
    item exposes ``triangles`` and ``vertices`` (presumably an Open3D point
    cloud -- TODO confirm).
    """
    hull_mesh, _unused = cloud.compute_convex_hull()
    faces = np.asarray(hull_mesh.triangles)
    verts = np.asarray(hull_mesh.vertices)
    return trimesh.base.Trimesh(verts, faces)
2021-10-12 17:57:09 +02:00
2021-09-11 20:49:55 +02:00
class ViewHalfSphere:
    """Sphere of camera positions centred on a bounding box.

    The radius is half the box extent along its third axis plus ``min_z_dist``.
    """

    def __init__(self, bbox, min_z_dist):
        """Derive centre and radius from ``bbox`` and the minimum distance."""
        self.center = bbox.center
        half_height = 0.5 * bbox.size[2]
        self.r = half_height + min_z_dist

    def get_view(self, theta, phi):
        """Return the view on the sphere at angles ``theta`` and ``phi``.

        The view is located on the sphere surface and looks at the centre.
        """
        offset = spherical_to_cartesian(self.r, theta, phi)
        eye = self.center + offset
        up_axis = np.r_[1.0, 0.0, 0.0]
        return look_at(eye, self.center, up_axis)

    def sample_view(self):
        """Not implemented."""
        raise NotImplementedError