diff --git a/active_grasp/baselines.py b/active_grasp/baselines.py index a4b0820..32afbb1 100644 --- a/active_grasp/baselines.py +++ b/active_grasp/baselines.py @@ -12,9 +12,9 @@ class SingleView(BasePolicy): Process a single image from the initial viewpoint. """ - def update(self): - self._integrate_latest_image() - self.best_grasp = self._predict_best_grasp() + def update(self, img, extrinsic): + self.integrate_img(img, extrinsic) + self.best_grasp = self.predict_best_grasp() self.done = True @@ -25,21 +25,17 @@ class TopView(BasePolicy): def activate(self, bbox): super().activate(bbox) - center = (bbox.min + bbox.max) / 2.0 - eye = np.r_[center[:2], center[2] + 0.3] + eye = np.r_[self.center[:2], self.center[2] + 0.3] up = np.r_[1.0, 0.0, 0.0] - self.target = self.T_B_task * (self.T_EE_cam * look_at(eye, center, up)).inv() - - def update(self): - current = tf.lookup(self.base_frame, self.ee_frame) - error = current.translation - self.target.translation + self.target = look_at(eye, self.center, up) + def update(self, img, extrinsic): + self.integrate_img(img, extrinsic) + error = extrinsic.translation - self.target.translation if np.linalg.norm(error) < 0.01: - self.best_grasp = self._predict_best_grasp() + self.best_grasp = self.predict_best_grasp() self.done = True - else: - self._integrate_latest_image() - return self.target + return self.target class RandomView(BasePolicy): @@ -47,31 +43,25 @@ class RandomView(BasePolicy): Move the camera to a random viewpoint on a circle centered above the target. """ - def __init__(self): - super().__init__() - self.r = 0.06 - self.h = 0.3 + def __init__(self, intrinsic): + super().__init__(intrinsic) + self.r = 0.06 # radius of the circle + self.h = 0.3 # distance above bbox center def activate(self, bbox): super().activate(bbox) - circle_center = (bbox.min + bbox.max) / 2.0 - circle_center[2] += self.h t = np.random.uniform(np.pi, 3.0 * np.pi) - eye = circle_center + np.r_[self.r * np.cos(t), self.r * np.sin(t), 0] - center = (self.bbox.min + self.bbox.max) / 2.0 + eye = self.center + np.r_[self.r * np.cos(t), self.r * np.sin(t), self.h] up = np.r_[1.0, 0.0, 0.0] - self.target = self.T_B_task * (self.T_EE_cam * look_at(eye, center, up)).inv() - - def update(self): - current = tf.lookup(self.base_frame, self.ee_frame) - error = current.translation - self.target.translation + self.target = look_at(eye, self.center, up) + def update(self, img, extrinsic): + self.integrate_img(img, extrinsic) + error = extrinsic.translation - self.target.translation if np.linalg.norm(error) < 0.01: - self.best_grasp = self._predict_best_grasp() + self.best_grasp = self.predict_best_grasp() self.done = True - else: - self._integrate_latest_image() - return self.target + return self.target class FixedTrajectory(BasePolicy): @@ -79,9 +69,9 @@ class FixedTrajectory(BasePolicy): Follow a pre-defined circular trajectory centered above the target object. """ - def __init__(self): - super().__init__() - self.r = 0.06 + def __init__(self, intrinsic): + super().__init__(intrinsic) + self.r = 0.08 self.h = 0.3 self.duration = 6.0 self.m = scipy.interpolate.interp1d([0, self.duration], [np.pi, 3.0 * np.pi]) @@ -89,21 +79,18 @@ class FixedTrajectory(BasePolicy): def activate(self, bbox): super().activate(bbox) self.tic = rospy.Time.now() - self.circle_center = (bbox.min + bbox.max) / 2.0 - self.circle_center[2] += self.h - def update(self): + def update(self, img, extrinsic): + self.integrate_img(img, extrinsic) elapsed_time = (rospy.Time.now() - self.tic).to_sec() if elapsed_time > self.duration: - self.best_grasp = self._predict_best_grasp() + self.best_grasp = self.predict_best_grasp() self.done = True else: - self._integrate_latest_image() t = self.m(elapsed_time) - eye = self.circle_center + np.r_[self.r * np.cos(t), self.r * np.sin(t), 0] - center = (self.bbox.min + self.bbox.max) / 2.0 + eye = self.center + np.r_[self.r * np.cos(t), self.r * np.sin(t), self.h] up = np.r_[1.0, 0.0, 0.0] - target = self.T_B_task * (self.T_EE_cam * look_at(eye, center, up)).inv() + target = look_at(eye, self.center, up) return target @@ -114,24 +101,24 @@ class AlignmentView(BasePolicy): def activate(self, bbox): super().activate(bbox) - self._integrate_latest_image() - self.best_grasp = self._predict_best_grasp() - if self.best_grasp: - R, t = self.best_grasp.rotation, self.best_grasp.translation - center = t + self.target = None + + def update(self, img, extrinsic): + self.integrate_img(img, extrinsic) + + if not self.target: + grasp = self.predict_best_grasp() + if not grasp: + self.done = True + return + R, t = grasp.pose.rotation, grasp.pose.translation eye = R.apply([0.0, 0.0, -0.16]) + t + center = t up = np.r_[1.0, 0.0, 0.0] - self.target = (self.T_EE_cam * look_at(eye, center, up)).inv() - else: - self.done = True - - def update(self): - current = tf.lookup(self.base_frame, self.ee_frame) - error = current.translation - self.target.translation + self.target = look_at(eye, center, up) + error = extrinsic.translation - self.target.translation if np.linalg.norm(error) < 0.01: - self.best_grasp = self._predict_best_grasp() + self.best_grasp = self.predict_best_grasp() self.done = True - else: - self._integrate_latest_image() - return self.target + return self.target diff --git a/active_grasp/controller.py b/active_grasp/controller.py index ce05e4c..519e2f3 100644 --- a/active_grasp/controller.py +++ b/active_grasp/controller.py @@ -1,77 +1,113 @@ +import copy +import cv_bridge from geometry_msgs.msg import PoseStamped import numpy as np import rospy +from sensor_msgs.msg import CameraInfo, Image import time from active_grasp.bbox import from_bbox_msg +from active_grasp.policy import make from active_grasp.srv import Reset, ResetRequest -from robot_helpers.ros.conversions import to_pose_stamped_msg +from robot_helpers.ros import tf +from robot_helpers.ros.conversions import * from robot_helpers.ros.panda import PandaGripperClient from robot_helpers.spatial import Rotation, Transform class GraspController: - def __init__(self, policy): - self.policy = policy - self._reset_env = rospy.ServiceProxy("reset", Reset) - self._load_parameters() - self._init_robot_control() + def __init__(self, policy_id): + self.reset_env = rospy.ServiceProxy("reset", Reset) + self.load_parameters() + self.lookup_transforms() + self.init_robot_connection() + self.init_camera_stream() + self.make_policy(policy_id) - def _load_parameters(self): - self.T_G_EE = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv() + def load_parameters(self): + self.base_frame = rospy.get_param("~base_frame_id") + self.ee_frame = rospy.get_param("~ee_frame_id") + self.cam_frame = rospy.get_param("~camera/frame_id") + self.info_topic = rospy.get_param("~camera/info_topic") + self.depth_topic = rospy.get_param("~camera/depth_topic") + self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv() - def _init_robot_control(self): + def lookup_transforms(self): + self.T_ee_cam = tf.lookup(self.ee_frame, self.cam_frame) + + def init_robot_connection(self): self.target_pose_pub = rospy.Publisher("command", PoseStamped, queue_size=10) self.gripper = PandaGripperClient() - def _send_cmd(self, pose): - msg = to_pose_stamped_msg(pose, "panda_link0") + def send_cmd(self, pose): + msg = to_pose_stamped_msg(pose, self.base_frame) self.target_pose_pub.publish(msg) - def run(self): - bbox = self._reset() - with Timer("search_time"): - grasp = self._search_grasp(bbox) - res = self._execute_grasp(grasp) - return self._collect_info(res) + def init_camera_stream(self): + msg = rospy.wait_for_message(self.info_topic, CameraInfo, rospy.Duration(2.0)) + self.intrinsic = from_camera_info_msg(msg) + self.cv_bridge = cv_bridge.CvBridge() + rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1) - def _reset(self): - res = self._reset_env(ResetRequest()) + def sensor_cb(self, msg): + self.latest_depth_msg = msg + + def make_policy(self, name): + self.policy = make(name, self.intrinsic) + + def run(self): + bbox = self.reset() + with Timer("search_time"): + grasp = self.search_grasp(bbox) + res = self.execute_grasp(grasp) + return self.collect_info(res) + + def reset(self): + res = self.reset_env(ResetRequest()) rospy.sleep(1.0) # wait for states to be updated return from_bbox_msg(res.bbox) - def _search_grasp(self, bbox): + def search_grasp(self, bbox): self.policy.activate(bbox) r = rospy.Rate(self.policy.rate) while True: - cmd = self.policy.update() + img, extrinsic = self.get_state() + next_extrinsic = self.policy.update(img, extrinsic) if self.policy.done: break - self._send_cmd(cmd) + self.send_cmd((self.T_ee_cam * next_extrinsic).inv()) r.sleep() return self.policy.best_grasp - def _execute_grasp(self, grasp): + def get_state(self): + msg = copy.deepcopy(self.latest_depth_msg) + img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32) + extrinsic = tf.lookup(self.cam_frame, self.base_frame, msg.header.stamp) + return img, extrinsic + + def execute_grasp(self, grasp): if not grasp: return "aborted" - T_B_G = self._postprocess(grasp) + T_base_grasp = self.postprocess(grasp.pose) self.gripper.move(0.08) # Move to an initial pose offset. - self._send_cmd(T_B_G * Transform.translation([0, 0, -0.05]) * self.T_G_EE) + self.send_cmd( + T_base_grasp * Transform.translation([0, 0, -0.05]) * self.T_grasp_ee + ) rospy.sleep(3.0) # Approach grasp pose. - self._send_cmd(T_B_G * self.T_G_EE) + self.send_cmd(T_base_grasp * self.T_grasp_ee) rospy.sleep(2.0) # Close the fingers. self.gripper.grasp() # Lift the object. - target = Transform.translation([0, 0, 0.2]) * T_B_G * self.T_G_EE - self._send_cmd(target) + target = Transform.translation([0, 0, 0.2]) * T_base_grasp * self.T_grasp_ee + self.send_cmd(target) rospy.sleep(2.0) # Check whether the object remains in the hand @@ -79,14 +115,14 @@ class GraspController: return "succeeded" if success else "failed" - def _postprocess(self, T_B_G): + def postprocess(self, T_base_grasp): # Ensure that the camera is pointing forward. - rot = T_B_G.rotation + rot = T_base_grasp.rotation if rot.as_matrix()[:, 0][0] < 0: - T_B_G.rotation = rot * Rotation.from_euler("z", np.pi) - return T_B_G + T_base_grasp.rotation = rot * Rotation.from_euler("z", np.pi) + return T_base_grasp - def _collect_info(self, result): + def collect_info(self, result): points = [p.translation for p in self.policy.viewpoints] d = np.sum([np.linalg.norm(p2 - p1) for p1, p2 in zip(points, points[1:])]) info = { diff --git a/active_grasp/policy.py b/active_grasp/policy.py index c27d5c7..49e3c3f 100644 --- a/active_grasp/policy.py +++ b/active_grasp/policy.py @@ -1,8 +1,6 @@ -import cv_bridge import numpy as np from pathlib import Path import rospy -from sensor_msgs.msg import CameraInfo, Image, PointCloud2 from .visualization import Visualizer from robot_helpers.ros import tf @@ -16,87 +14,68 @@ class Policy: def activate(self, bbox): raise NotImplementedError - def update(self): + def update(self, img, extrinsic): raise NotImplementedError class BasePolicy(Policy): - def __init__(self): - self.cv_bridge = cv_bridge.CvBridge() - self.vgn = VGN(Path(rospy.get_param("vgn/model"))) - self.finger_depth = 0.05 + def __init__(self, intrinsic): + self.intrinsic = intrinsic self.rate = 5 - self._load_parameters() - self._lookup_transforms() - self._init_camera_stream() - self._init_publishers() - self._init_visualizer() + self.load_parameters() + self.init_visualizer() - def _load_parameters(self): - self.task_frame = rospy.get_param("~frame_id") - self.base_frame = rospy.get_param("~base_frame_id") - self.ee_frame = rospy.get_param("~ee_frame_id") - self.cam_frame = rospy.get_param("~camera/frame_id") - self.info_topic = rospy.get_param("~camera/info_topic") - self.depth_topic = rospy.get_param("~camera/depth_topic") + def load_parameters(self): + self.base_frame = rospy.get_param("active_grasp/base_frame_id") + self.task_frame = "task" + self.vgn = VGN(Path(rospy.get_param("vgn/model"))) - def _lookup_transforms(self): - self.T_B_task = tf.lookup(self.base_frame, self.task_frame) - self.T_EE_cam = tf.lookup(self.ee_frame, self.cam_frame) - - def _init_camera_stream(self): - msg = rospy.wait_for_message(self.info_topic, CameraInfo, rospy.Duration(2.0)) - self.intrinsic = from_camera_info_msg(msg) - rospy.Subscriber(self.depth_topic, Image, self._sensor_cb, queue_size=1) - - def _sensor_cb(self, msg): - self.img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32) - self.extrinsic = tf.lookup(self.cam_frame, self.task_frame, msg.header.stamp) - - def _init_publishers(self): - self.scene_cloud_pub = rospy.Publisher("scene_cloud", PointCloud2, queue_size=1) - - def _init_visualizer(self): - self.visualizer = Visualizer(self.task_frame) + def init_visualizer(self): + self.visualizer = Visualizer(self.base_frame) def activate(self, bbox): self.bbox = bbox + + # Define the VGN task frame s.t. the bounding box is in its center + self.center = 0.5 * (bbox.min + bbox.max) + self.T_base_task = Transform.translation(self.center - np.full(3, 0.15)) + tf.broadcast(self.T_base_task, self.base_frame, self.task_frame) + rospy.sleep(0.1) # wait for the transform to be published + self.tsdf = UniformTSDFVolume(0.3, 40) self.viewpoints = [] self.done = False - self.best_grasp = None # grasp pose defined w.r.t. the robot's base frame + self.best_grasp = None + self.visualizer.clear() self.visualizer.bbox(bbox) - def _integrate_latest_image(self): - self.viewpoints.append(self.extrinsic.inv()) - self.tsdf.integrate( - self.img, - self.intrinsic, - self.extrinsic, - ) - self._publish_scene_cloud() + def integrate_img(self, img, extrinsic): + self.viewpoints.append(extrinsic.inv()) + self.tsdf.integrate(img, self.intrinsic, extrinsic * self.T_base_task) + self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud()) self.visualizer.path(self.viewpoints) - def _publish_scene_cloud(self): - cloud = self.tsdf.get_scene_cloud() - msg = to_cloud_msg(self.task_frame, np.asarray(cloud.points)) - self.scene_cloud_pub.publish(msg) - - def _predict_best_grasp(self): + def predict_best_grasp(self): tsdf_grid = self.tsdf.get_grid() out = self.vgn.predict(tsdf_grid) score_fn = lambda g: g.pose.translation[2] grasps = compute_grasps(self.tsdf.voxel_size, out, score_fn, max_filter_size=3) - grasps = self._select_grasps_on_target_object(grasps) - return self.T_B_task * grasps[0].pose if len(grasps) > 0 else None + grasps = self.transform_grasps_to_base_frame(grasps) + grasps = self.select_grasps_on_target_object(grasps) + return grasps[0] if len(grasps) > 0 else None - def _select_grasps_on_target_object(self, grasps): + def transform_grasps_to_base_frame(self, grasps): + for grasp in grasps: + grasp.pose = self.T_base_task * grasp.pose + return grasps + + def select_grasps_on_target_object(self, grasps): result = [] - for g in grasps: - tip = g.pose.rotation.apply([0, 0, 0.05]) + g.pose.translation + for grasp in grasps: + tip = grasp.pose.rotation.apply([0, 0, 0.05]) + grasp.pose.translation if self.bbox.is_inside(tip): - result.append(g) + result.append(grasp) return result @@ -108,8 +87,8 @@ def register(id, cls): registry[id] = cls -def make(id): +def make(id, *args, **kwargs): if id in registry: - return registry[id]() + return registry[id](*args, **kwargs) else: raise ValueError("{} policy does not exist.".format(id)) diff --git a/active_grasp/simulation.py b/active_grasp/simulation.py index 52007b6..93d7f6b 100644 --- a/active_grasp/simulation.py +++ b/active_grasp/simulation.py @@ -46,8 +46,8 @@ class Simulation: self.origin = [-0.3, -0.5 * self.length, 0.5] def load_robot(self): - self.T_W_B = Transform(Rotation.identity(), np.r_[-0.6, 0.0, 0.4]) - self.arm = BtPandaArm(self.panda_urdf, self.T_W_B) + self.T_world_base = Transform.translation(np.r_[-0.6, 0.0, 0.4]) + self.arm = BtPandaArm(self.panda_urdf, self.T_world_base) self.gripper = BtPandaGripper(self.arm) self.model = Model(self.panda_urdf, self.arm.base_frame, self.arm.ee_frame) self.camera = BtCamera(320, 240, 1.047, 0.1, 1.0, self.arm.uid, 11) @@ -135,33 +135,34 @@ class Simulation: def get_target_bbox(self, uid): aabb_min, aabb_max = p.getAABB(uid) - aabb_min = np.array(aabb_min) - self.origin - aabb_max = np.array(aabb_max) - self.origin + # Transform the coordinates to base_frame + aabb_min = np.array(aabb_min) - self.T_world_base.translation + aabb_max = np.array(aabb_max) - self.T_world_base.translation return AABBox(aabb_min, aabb_max) class CartesianPoseController: def __init__(self, model, frame, x0): - self._model = model - self._frame = frame + self.model = model + self.frame = frame self.kp = np.ones(6) * 4.0 - self.max_linear_vel = 0.2 + self.max_linear_vel = 0.1 self.max_angular_vel = 1.57 self.x_d = x0 def update(self, q): - x = self._model.pose(self._frame, q) + x = self.model.pose(self.frame, q) error = np.zeros(6) error[:3] = self.x_d.translation - x.translation error[3:] = (self.x_d.rotation * x.rotation.inv()).as_rotvec() - dx = self._limit_rate(self.kp * error) - J_pinv = np.linalg.pinv(self._model.jacobian(self._frame, q)) + dx = self.limit_rate(self.kp * error) + J_pinv = np.linalg.pinv(self.model.jacobian(self.frame, q)) cmd = np.dot(J_pinv, dx) return cmd - def _limit_rate(self, dx): + def limit_rate(self, dx): linear, angular = dx[:3], dx[3:] linear = np.clip(linear, -self.max_linear_vel, self.max_linear_vel) angular = np.clip(angular, -self.max_angular_vel, self.max_angular_vel) diff --git a/active_grasp/visualization.py b/active_grasp/visualization.py index 7b6acd5..78c3e79 100644 --- a/active_grasp/visualization.py +++ b/active_grasp/visualization.py @@ -1,15 +1,16 @@ import numpy as np import rospy - from robot_helpers.ros.rviz import * from robot_helpers.spatial import Transform +from vgn.utils import * class Visualizer: def __init__(self, frame, topic="visualization_marker_array"): self.frame = frame self.marker_pub = rospy.Publisher(topic, MarkerArray, queue_size=1) + self.scene_cloud_pub = rospy.Publisher("scene_cloud", PointCloud2, queue_size=1) def clear(self): marker = Marker(action=Marker.DELETEALL) @@ -22,6 +23,10 @@ class Visualizer: marker = create_cube_marker(self.frame, pose, scale, color, ns="bbox") self.draw([marker]) + def scene_cloud(self, frame, cloud): + msg = to_cloud_msg(frame, np.asarray(cloud.points)) + self.scene_cloud_pub.publish(msg) + def path(self, poses): color = np.r_[31, 119, 180] / 255.0 points = [p.translation for p in poses] diff --git a/config/active_grasp.yaml b/cfg/active_grasp.yaml similarity index 91% rename from config/active_grasp.yaml rename to cfg/active_grasp.yaml index e089e71..10983da 100644 --- a/config/active_grasp.yaml +++ b/cfg/active_grasp.yaml @@ -2,8 +2,6 @@ bt_sim: gui: True active_grasp: - frame_id: task - length: 0.3 base_frame_id: panda_link0 ee_frame_id: panda_hand ee_grasp_offset: [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.065] @@ -11,7 +9,7 @@ active_grasp: frame_id: camera_optical_frame info_topic: /camera/depth/camera_info depth_topic: /camera/depth/image_raw - + vgn: model: $(find vgn)/assets/models/vgn_conv.pth finger_depth: 0.05 diff --git a/launch/active_grasp.launch b/launch/active_grasp.launch index f23dbd5..69ed210 100644 --- a/launch/active_grasp.launch +++ b/launch/active_grasp.launch @@ -2,7 +2,7 @@ - + diff --git a/launch/active_grasp.rviz b/launch/active_grasp.rviz index fc77e2a..0279696 100644 --- a/launch/active_grasp.rviz +++ b/launch/active_grasp.rviz @@ -3,10 +3,9 @@ Panels: Help Height: 78 Name: Displays Property Tree Widget: - Expanded: - - /TF1/Tree1 + Expanded: ~ Splitter Ratio: 0.5 - Tree Height: 574 + Tree Height: 471 - Class: rviz/Selection Name: Selection - Class: rviz/Tool Properties @@ -137,9 +136,9 @@ Visualization Manager: Frames: All Enabled: false camera_optical_frame: - Value: true - panda_hand: Value: false + panda_hand: + Value: true panda_leftfinger: Value: false panda_link0: @@ -165,7 +164,7 @@ Visualization Manager: task: Value: true world: - Value: true + Value: false Marker Alpha: 1 Marker Scale: 0.5 Name: TF @@ -224,7 +223,7 @@ Visualization Manager: Value: true - Alpha: 1 Axes Length: 0.05000000074505806 - Axes Radius: 0.004999999888241291 + Axes Radius: 0.009999999776482582 Class: rviz/Pose Color: 255; 25; 0 Enabled: true @@ -244,6 +243,7 @@ Visualization Manager: Name: Markers Namespaces: bbox: true + path: true Queue Size: 100 Value: true Enabled: true @@ -274,7 +274,7 @@ Visualization Manager: Views: Current: Class: rviz/Orbit - Distance: 1.3517695665359497 + Distance: 1.2179546356201172 Enable Stereo Rendering: Stereo Eye Separation: 0.05999999865889549 Stereo Focal Distance: 1 @@ -282,25 +282,25 @@ Visualization Manager: Value: false Field of View: 0.7853981852531433 Focal Point: - X: 0.3073185980319977 - Y: 0.050485748797655106 - Z: 0.3944588601589203 + X: 0.2475447803735733 + Y: 0.03526053577661514 + Z: 0.4393550157546997 Focal Shape Fixed Size: true Focal Shape Size: 0.05000000074505806 Invert Z Axis: false Name: Current View Near Clip Distance: 0.009999999776482582 - Pitch: 0.4747979938983917 + Pitch: 0.2147984653711319 Target Frame: - Yaw: 5.098489761352539 + Yaw: 5.383471488952637 Saved: ~ Window Geometry: Displays: collapsed: false - Height: 871 + Height: 768 Hide Left Dock: false Hide Right Dock: true - QMainWindow State: 000000ff00000000fd000000040000000000000156000002c9fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000002c9000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f000002b0fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d000002b0000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e100000197000000030000050e0000003efc0100000002fb0000000800540069006d006501000000000000050e000002eb00fffffffb0000000800540069006d00650100000000000004500000000000000000000003b2000002c900000004000000040000000800000008fc00000002000000020000000000000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 + QMainWindow State: 000000ff00000000fd00000004000000000000015600000262fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d00000262000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f000002b0fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d000002b0000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000003e50000003efc0100000002fb0000000800540069006d00650100000000000003e5000002eb00fffffffb0000000800540069006d00650100000000000004500000000000000000000002890000026200000004000000040000000800000008fc00000002000000020000000000000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 Selection: collapsed: false Time: @@ -309,6 +309,6 @@ Window Geometry: collapsed: false Views: collapsed: true - Width: 1294 - X: 104 - Y: 374 + Width: 997 + X: 876 + Y: 127 diff --git a/scripts/bt_sim_node.py b/scripts/bt_sim_node.py index e43f32d..bf51b5a 100755 --- a/scripts/bt_sim_node.py +++ b/scripts/bt_sim_node.py @@ -20,11 +20,10 @@ class BtSimNode: def __init__(self): self.gui = rospy.get_param("~gui", True) self.sim = Simulation(gui=self.gui) - self._init_plugins() - self._advertise_services() - self._broadcast_transforms() + self.init_plugins() + self.advertise_services() - def _init_plugins(self): + def init_plugins(self): self.plugins = [ PhysicsPlugin(self.sim), JointStatePlugin(self.sim.arm, self.sim.gripper), @@ -34,20 +33,10 @@ class BtSimNode: CameraPlugin(self.sim.camera), ] - def _advertise_services(self): + def advertise_services(self): rospy.Service("seed", Seed, self.seed) rospy.Service("reset", Reset, self.reset) - def _broadcast_transforms(self): - self.static_broadcaster = tf2_ros.StaticTransformBroadcaster() - msgs = [ - to_transform_stamped_msg(self.sim.T_W_B, "world", "panda_link0"), - to_transform_stamped_msg( - Transform.translation(self.sim.origin), "world", "task" - ), - ] - self.static_broadcaster.sendTransform(msgs) - def seed(self, req): self.sim.seed(req.seed) return SeedResponse() @@ -63,10 +52,10 @@ class BtSimNode: return res def run(self): - self._start_plugins() + self.start_plugins() rospy.spin() - def _start_plugins(self): + def start_plugins(self): for plugin in self.plugins: plugin.thread.start() plugin.is_running = True @@ -77,17 +66,17 @@ class Plugin: def __init__(self, rate): self.rate = rate - self.thread = Thread(target=self._loop, daemon=True) + self.thread = Thread(target=self.loop, daemon=True) self.is_running = False - def _loop(self): + def loop(self): rate = rospy.Rate(self.rate) while not rospy.is_shutdown(): if self.is_running: - self._update() + self.update() rate.sleep() - def _update(self): + def update(self): raise NotImplementedError @@ -96,7 +85,7 @@ class PhysicsPlugin(Plugin): super().__init__(sim.rate) self.sim = sim - def _update(self): + def update(self): self.sim.step() @@ -107,7 +96,7 @@ class JointStatePlugin(Plugin): self.gripper = gripper self.pub = rospy.Publisher("joint_states", JointState, queue_size=10) - def _update(self): + def update(self): q, _ = self.arm.get_state() width = self.gripper.read() msg = JointState() @@ -125,13 +114,13 @@ class ArmControllerPlugin(Plugin): super().__init__(rate) self.arm = arm self.controller = controller - rospy.Subscriber("command", PoseStamped, self._target_cb) + rospy.Subscriber("command", PoseStamped, self.target_cb) - def _target_cb(self, msg): + def target_cb(self, msg): assert msg.header.frame_id == self.arm.base_frame self.controller.x_d = from_pose_msg(msg.pose) - def _update(self): + def update(self): q, _ = self.arm.get_state() cmd = self.controller.update(q) self.arm.set_desired_joint_velocities(cmd) @@ -142,20 +131,20 @@ class MoveActionPlugin(Plugin): super().__init__(rate) self.gripper = gripper self.dt = 1.0 / self.rate - self._init_action_server() + self.init_action_server() - def _init_action_server(self): + def init_action_server(self): name = "/franka_gripper/move" self.action_server = SimpleActionServer(name, MoveAction, auto_start=False) - self.action_server.register_goal_callback(self._action_goal_cb) + self.action_server.register_goal_callback(self.action_goal_cb) self.action_server.start() - def _action_goal_cb(self): + def action_goal_cb(self): self.elapsed_time = 0.0 goal = self.action_server.accept_new_goal() self.gripper.set_desired_width(goal.width) - def _update(self): + def update(self): if self.action_server.is_active(): self.elapsed_time += self.dt if self.elapsed_time > 1.0: @@ -167,20 +156,20 @@ class GraspActionPlugin(Plugin): super().__init__(rate) self.gripper = gripper self.dt = 1.0 / self.rate - self._init_action_server() + self.init_action_server() - def _init_action_server(self): + def init_action_server(self): name = "/franka_gripper/grasp" self.action_server = SimpleActionServer(name, GraspAction, auto_start=False) - self.action_server.register_goal_callback(self._action_goal_cb) + self.action_server.register_goal_callback(self.action_goal_cb) self.action_server.start() - def _action_goal_cb(self): + def action_goal_cb(self): self.elapsed_time = 0.0 goal = self.action_server.accept_new_goal() self.gripper.set_desired_width(goal.width) - def _update(self): + def update(self): if self.action_server.is_active(): self.elapsed_time += self.dt if self.elapsed_time > 1.0: @@ -188,21 +177,20 @@ class GraspActionPlugin(Plugin): class CameraPlugin(Plugin): - def __init__(self, camera, name="camera"): - rate = rospy.get_param("~cam_rate", 5) + def __init__(self, camera, name="camera", rate=5): super().__init__(rate) self.camera = camera self.name = name self.cv_bridge = cv_bridge.CvBridge() - self._init_publishers() + self.init_publishers() - def _init_publishers(self): + def init_publishers(self): topic = self.name + "/depth/camera_info" self.info_pub = rospy.Publisher(topic, CameraInfo, queue_size=10) topic = self.name + "/depth/image_raw" self.depth_pub = rospy.Publisher(topic, Image, queue_size=10) - def _update(self): + def update(self): stamp = rospy.Time.now() msg = to_camera_info_msg(self.camera.intrinsic) diff --git a/scripts/run.py b/scripts/run.py index c7b5213..1870b0b 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -6,10 +6,33 @@ import rospy from tqdm import tqdm from active_grasp.controller import * -from active_grasp.policy import make, registry +from active_grasp.policy import registry from active_grasp.srv import Seed +def main(): + rospy.init_node("active_grasp") + parser = create_parser() + args = parser.parse_args() + controller = GraspController(args.policy) + logger = Logger(args.logdir, args.policy) + + seed_simulation(args.seed) + + for _ in tqdm(range(args.runs)): + info = controller.run() + logger.log_run(info) + + +def create_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("policy", type=str, choices=registry.keys()) + parser.add_argument("--runs", type=int, default=10) + parser.add_argument("--logdir", type=Path, default="logs") + parser.add_argument("--seed", type=int, default=12) + return parser + + class Logger: def __init__(self, logdir, policy): stamp = datetime.now().strftime("%y%m%d-%H%M%S") @@ -21,34 +44,10 @@ class Logger: df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False) -def create_parser(): - parser = argparse.ArgumentParser() - parser.add_argument("policy", type=str, choices=registry.keys()) - parser.add_argument("--runs", type=int, default=10) - parser.add_argument("--logdir", type=Path, default="logs") - parser.add_argument("--seed", type=int, default=12) - return parser - - def seed_simulation(seed): rospy.ServiceProxy("seed", Seed)(seed) rospy.sleep(1.0) -def main(): - rospy.init_node("active_grasp") - parser = create_parser() - args = parser.parse_args() - policy = make(args.policy) - controller = GraspController(policy) - logger = Logger(args.logdir, args.policy) - - seed_simulation(args.seed) - - for _ in tqdm(range(args.runs)): - info = controller.run() - logger.log_run(info) - - if __name__ == "__main__": main()