Publish simulation state in multiple threads

This commit is contained in:
Michel Breyer 2021-07-14 16:52:53 +02:00
parent 7d3283ff32
commit 8989115bd7
5 changed files with 124 additions and 93 deletions

View File

@ -59,7 +59,7 @@ class GraspController:
# Approach grasp pose. # Approach grasp pose.
self.controller.send_target(T_B_G * self.T_G_EE) self.controller.send_target(T_B_G * self.T_G_EE)
rospy.sleep(1.0) rospy.sleep(2.0)
# Close the fingers. # Close the fingers.
self.gripper.grasp() self.gripper.grasp()

View File

@ -19,7 +19,6 @@ class Simulation(BtSim):
self.load_table() self.load_table()
self.load_robot() self.load_robot()
self.load_controller() self.load_controller()
self.reset()
def configure_visualizer(self): def configure_visualizer(self):
# p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0) # p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)
@ -45,7 +44,9 @@ class Simulation(BtSim):
self.camera = BtCamera(320, 240, 1.047, 0.1, 1.0, self.arm.uid, 11) self.camera = BtCamera(320, 240, 1.047, 0.1, 1.0, self.arm.uid, 11)
def load_controller(self): def load_controller(self):
self.controller = CartesianPoseController(self.model, self.arm.ee_frame, None) q, _ = self.arm.get_state()
x0 = self.model.pose(self.arm.ee_frame, q)
self.controller = CartesianPoseController(self.model, self.arm.ee_frame, x0)
def reset(self): def reset(self):
self.remove_all_objects() self.remove_all_objects()

View File

@ -1,6 +1,6 @@
<?xml version="1.0" ?> <?xml version="1.0" ?>
<launch> <launch>
<arg name="launch_rviz" default="true" /> <arg name="launch_rviz" default="false" />
<rosparam command="load" file="$(find active_grasp)launch/active_grasp.yaml" subst_value="true" /> <rosparam command="load" file="$(find active_grasp)launch/active_grasp.yaml" subst_value="true" />

View File

@ -1,7 +1,6 @@
bt_sim: bt_sim:
gui: True gui: True
seed: 12 seed: 12
cam_pub_rate: 10
active_grasp: active_grasp:
frame_id: task frame_id: task
@ -10,9 +9,9 @@ active_grasp:
ee_frame_id: panda_hand ee_frame_id: panda_hand
ee_grasp_offset: [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.065] ee_grasp_offset: [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.065]
camera: camera:
frame_id: cam_optical_frame frame_id: camera_optical_frame
info_topic: /cam/depth/camera_info info_topic: /camera/depth/camera_info
depth_topic: /cam/depth/image_raw depth_topic: /camera/depth/image_raw
vgn: vgn:
model: $(find vgn)/assets/models/vgn_conv.pth model: $(find vgn)/assets/models/vgn_conv.pth

View File

@ -1,12 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import actionlib from actionlib import SimpleActionServer
import cv_bridge import cv_bridge
import franka_gripper.msg from franka_gripper.msg import *
from geometry_msgs.msg import PoseStamped from geometry_msgs.msg import PoseStamped
import numpy as np import numpy as np
import rospy import rospy
from sensor_msgs.msg import JointState, Image, CameraInfo from sensor_msgs.msg import JointState, Image, CameraInfo
from threading import Thread
import tf2_ros import tf2_ros
from active_grasp.srv import Reset, ResetResponse from active_grasp.srv import Reset, ResetResponse
@ -17,31 +18,29 @@ from robot_utils.ros.conversions import *
class BtSimNode: class BtSimNode:
def __init__(self): def __init__(self):
self.load_parameters()
rng = self.get_rng()
self.sim = Simulation(gui=self.gui, rng=rng)
self.robot_state_interface = RobotStateInterface(self.sim.arm, self.sim.gripper)
self.arm_interface = ArmInterface(self.sim.arm, self.sim.controller)
self.gripper_interface = GripperInterface(self.sim.gripper)
self.camera_interface = CameraInterface(self.sim.camera)
self.step_cnt = 0
self.reset_requested = False
self.advertise_services()
self.broadcast_transforms()
def load_parameters(self):
self.gui = rospy.get_param("~gui", True) self.gui = rospy.get_param("~gui", True)
self.cam_pub_rate = rospy.get_param("~cam_pub_rate")
def get_rng(self):
seed = rospy.get_param("~seed", None) seed = rospy.get_param("~seed", None)
return np.random.default_rng(seed) if seed else np.random
def advertise_services(self): rng = np.random.default_rng(seed) if seed else np.random
self.sim = Simulation(gui=self.gui, rng=rng)
self._init_plugins()
self._advertise_services()
self._broadcast_transforms()
def _init_plugins(self):
self.plugins = [
PhysicsPlugin(self.sim),
JointStatePlugin(self.sim.arm, self.sim.gripper),
ArmControllerPlugin(self.sim.arm, self.sim.controller),
GripperControllerPlugin(self.sim.gripper),
CameraPlugin(self.sim.camera),
]
def _advertise_services(self):
rospy.Service("reset", Reset, self.reset) rospy.Service("reset", Reset, self.reset)
def broadcast_transforms(self): def _broadcast_transforms(self):
self.static_broadcaster = tf2_ros.StaticTransformBroadcaster() self.static_broadcaster = tf2_ros.StaticTransformBroadcaster()
msgs = [ msgs = [
to_transform_stamped_msg(self.sim.T_W_B, "world", "panda_link0"), to_transform_stamped_msg(self.sim.T_W_B, "world", "panda_link0"),
@ -52,38 +51,63 @@ class BtSimNode:
self.static_broadcaster.sendTransform(msgs) self.static_broadcaster.sendTransform(msgs)
def reset(self, req): def reset(self, req):
self.reset_requested = True for plugin in self.plugins:
rospy.sleep(1.0) # wait for the latest sim step to finish plugin.is_running = False
rospy.sleep(1.0) # TODO replace with a read-write lock
bbox = self.sim.reset() bbox = self.sim.reset()
res = ResetResponse(to_bbox_msg(bbox)) res = ResetResponse(to_bbox_msg(bbox))
self.step_cnt = 0
self.reset_requested = False for plugin in self.plugins:
plugin.is_running = True
return res return res
def run(self): def run(self):
rate = rospy.Rate(self.sim.rate) self._start_plugins()
rospy.spin()
def _start_plugins(self):
for plugin in self.plugins:
plugin.thread.start()
plugin.is_running = True
class Plugin:
"""A plugin that spins at a constant rate in its own thread."""
def __init__(self, rate):
self.rate = rate
self.thread = Thread(target=self._loop, daemon=True)
self.is_running = False
def _loop(self):
rate = rospy.Rate(self.rate)
while not rospy.is_shutdown(): while not rospy.is_shutdown():
if not self.reset_requested: if self.is_running:
self.handle_updates() self._update()
self.sim.step()
self.step_cnt = (self.step_cnt + 1) % self.sim.rate
rate.sleep() rate.sleep()
def handle_updates(self): def _update(self):
self.robot_state_interface.update() raise NotImplementedError
self.arm_interface.update()
self.gripper_interface.update(self.sim.dt)
if self.step_cnt % int(self.sim.rate / self.cam_pub_rate) == 0:
self.camera_interface.update()
class RobotStateInterface: class PhysicsPlugin(Plugin):
def __init__(self, arm, gripper): def __init__(self, sim):
super().__init__(sim.rate)
self.sim = sim
def _update(self):
self.sim.step()
class JointStatePlugin(Plugin):
def __init__(self, arm, gripper, rate=30):
super().__init__(rate)
self.arm = arm self.arm = arm
self.gripper = gripper self.gripper = gripper
self.joint_pub = rospy.Publisher("joint_states", JointState, queue_size=10) self.pub = rospy.Publisher("joint_states", JointState, queue_size=10)
def update(self): def _update(self):
q, _ = self.arm.get_state() q, _ = self.arm.get_state()
width = self.gripper.read() width = self.gripper.read()
msg = JointState() msg = JointState()
@ -93,86 +117,93 @@ class RobotStateInterface:
"panda_finger_joint2", "panda_finger_joint2",
] ]
msg.position = np.r_[q, 0.5 * width, 0.5 * width] msg.position = np.r_[q, 0.5 * width, 0.5 * width]
self.joint_pub.publish(msg) self.pub.publish(msg)
class ArmInterface: class ArmControllerPlugin(Plugin):
def __init__(self, arm, controller): def __init__(self, arm, controller, rate=30):
super().__init__(rate)
self.arm = arm self.arm = arm
self.controller = controller self.controller = controller
rospy.Subscriber("command", PoseStamped, self.target_cb) rospy.Subscriber("command", PoseStamped, self._target_cb)
def update(self): def _target_cb(self, msg):
assert msg.header.frame_id == self.arm.base_frame
self.controller.x_d = from_pose_msg(msg.pose)
def _update(self):
q, _ = self.arm.get_state() q, _ = self.arm.get_state()
cmd = self.controller.update(q) cmd = self.controller.update(q)
self.arm.set_desired_joint_velocities(cmd) self.arm.set_desired_joint_velocities(cmd)
def target_cb(self, msg):
assert msg.header.frame_id == self.arm.base_frame
self.controller.x_d = from_pose_msg(msg.pose)
class GripperControllerPlugin(Plugin):
class GripperInterface: def __init__(self, gripper, rate=10):
def __init__(self, gripper): super().__init__(rate)
self.gripper = gripper self.gripper = gripper
self.move_server = actionlib.SimpleActionServer( self.dt = 1.0 / self.rate
"/franka_gripper/move", self._init_move_action_server()
franka_gripper.msg.MoveAction, self._init_grasp_action_server()
auto_start=False,
) def _init_move_action_server(self):
self.move_server.register_goal_callback(self.move_action_goal_cb) name = "/franka_gripper/move"
self.move_server = SimpleActionServer(name, MoveAction, auto_start=False)
self.move_server.register_goal_callback(self._move_action_goal_cb)
self.move_server.start() self.move_server.start()
self.grasp_server = actionlib.SimpleActionServer( def _init_grasp_action_server(self):
"/franka_gripper/grasp", name = "/franka_gripper/grasp"
franka_gripper.msg.GraspAction, self.grasp_server = SimpleActionServer(name, GraspAction, auto_start=False)
auto_start=False, self.grasp_server.register_goal_callback(self._grasp_action_goal_cb)
)
self.grasp_server.register_goal_callback(self.grasp_action_goal_cb)
self.grasp_server.start() self.grasp_server.start()
def move_action_goal_cb(self): def _move_action_goal_cb(self):
self.elapsed_time_since_move_action_goal = 0.0 self.elapsed_time_since_move_action_goal = 0.0
goal = self.move_server.accept_new_goal() goal = self.move_server.accept_new_goal()
self.gripper.set_desired_width(goal.width) self.gripper.set_desired_width(goal.width)
def grasp_action_goal_cb(self): def _grasp_action_goal_cb(self):
self.elapsed_time_since_grasp_action_goal = 0.0 self.elapsed_time_since_grasp_action_goal = 0.0
goal = self.grasp_server.accept_new_goal() goal = self.grasp_server.accept_new_goal()
self.gripper.set_desired_width(goal.width) self.gripper.set_desired_width(goal.width)
def update(self, dt): def _update(self):
if self.move_server.is_active(): if self.move_server.is_active():
self.elapsed_time_since_move_action_goal += dt self.elapsed_time_since_move_action_goal += self.dt
if self.elapsed_time_since_move_action_goal > 1.0: if self.elapsed_time_since_move_action_goal > 1.0:
self.move_server.set_succeeded() self.move_server.set_succeeded()
if self.grasp_server.is_active(): if self.grasp_server.is_active():
self.elapsed_time_since_grasp_action_goal += dt self.elapsed_time_since_grasp_action_goal += self.dt
if self.elapsed_time_since_grasp_action_goal > 1.0: if self.elapsed_time_since_grasp_action_goal > 1.0:
self.grasp_server.set_succeeded() self.grasp_server.set_succeeded()
class CameraInterface: class CameraPlugin(Plugin):
def __init__(self, camera): def __init__(self, camera, name="camera", rate=10):
super().__init__(rate)
self.camera = camera self.camera = camera
self.name = name
self.cv_bridge = cv_bridge.CvBridge() self.cv_bridge = cv_bridge.CvBridge()
self.cam_info_msg = to_camera_info_msg(self.camera.intrinsic) self._init_publishers()
self.cam_info_msg.header.frame_id = "cam_optical_frame"
self.cam_info_pub = rospy.Publisher(
"/cam/depth/camera_info",
CameraInfo,
queue_size=10,
)
self.depth_pub = rospy.Publisher("/cam/depth/image_raw", Image, queue_size=10)
def update(self): def _init_publishers(self):
topic = self.name + "/depth/camera_info"
self.info_pub = rospy.Publisher(topic, CameraInfo, queue_size=10)
topic = self.name + "/depth/image_raw"
self.depth_pub = rospy.Publisher(topic, Image, queue_size=10)
def _update(self):
stamp = rospy.Time.now() stamp = rospy.Time.now()
self.cam_info_msg.header.stamp = stamp
self.cam_info_pub.publish(self.cam_info_msg) msg = to_camera_info_msg(self.camera.intrinsic)
msg.header.frame_id = self.name + "_optical_frame"
msg.header.stamp = stamp
self.info_pub.publish(msg)
img = self.camera.get_image() img = self.camera.get_image()
depth_msg = self.cv_bridge.cv2_to_imgmsg(img.depth) msg = self.cv_bridge.cv2_to_imgmsg(img.depth)
depth_msg.header.stamp = stamp msg.header.stamp = stamp
self.depth_pub.publish(depth_msg) self.depth_pub.publish(msg)
def main(): def main():