Log metrics
This commit is contained in:
parent
1375cedcb5
commit
66cbf39516
1
.gitignore
vendored
1
.gitignore
vendored
@ -132,3 +132,4 @@ dmypy.json
|
|||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
assets/
|
assets/
|
||||||
|
logs/
|
||||||
|
@ -20,9 +20,11 @@ class GraspController:
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
bbox = self.reset()
|
bbox = self.reset()
|
||||||
|
with Timer("exploration_time"):
|
||||||
grasp = self.explore(bbox)
|
grasp = self.explore(bbox)
|
||||||
if grasp:
|
with Timer("execution_time"):
|
||||||
self.execute_grasp(grasp)
|
res = self.execute_grasp(grasp)
|
||||||
|
return self.collect_info(res)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
req = ResetRequest()
|
req = ResetRequest()
|
||||||
@ -42,6 +44,9 @@ class GraspController:
|
|||||||
return self.policy.best_grasp
|
return self.policy.best_grasp
|
||||||
|
|
||||||
def execute_grasp(self, grasp):
|
def execute_grasp(self, grasp):
|
||||||
|
if not grasp:
|
||||||
|
return "aborted"
|
||||||
|
|
||||||
T_B_G = self.postprocess(grasp)
|
T_B_G = self.postprocess(grasp)
|
||||||
|
|
||||||
self.gripper.move(0.08)
|
self.gripper.move(0.08)
|
||||||
@ -65,7 +70,9 @@ class GraspController:
|
|||||||
rospy.sleep(2.0)
|
rospy.sleep(2.0)
|
||||||
|
|
||||||
# Check whether the object remains in the hand
|
# Check whether the object remains in the hand
|
||||||
return self.gripper.read() > 0.005
|
success = self.gripper.read() > 0.005
|
||||||
|
|
||||||
|
return "succeeded" if success else "failed"
|
||||||
|
|
||||||
def postprocess(self, T_B_G):
|
def postprocess(self, T_B_G):
|
||||||
# Ensure that the camera is pointing forward.
|
# Ensure that the camera is pointing forward.
|
||||||
@ -73,3 +80,16 @@ class GraspController:
|
|||||||
if rot.as_matrix()[:, 0][0] < 0:
|
if rot.as_matrix()[:, 0][0] < 0:
|
||||||
T_B_G.rotation = rot * Rotation.from_euler("z", np.pi)
|
T_B_G.rotation = rot * Rotation.from_euler("z", np.pi)
|
||||||
return T_B_G
|
return T_B_G
|
||||||
|
|
||||||
|
def collect_info(self, result):
|
||||||
|
points = [p.translation for p in self.policy.viewpoints]
|
||||||
|
d = np.sum([np.linalg.norm(p2 - p1) for p1, p2 in zip(points, points[1:])])
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"result": result,
|
||||||
|
"viewpoint_count": len(points),
|
||||||
|
"distance_travelled": d,
|
||||||
|
}
|
||||||
|
info.update(self.policy.info)
|
||||||
|
info.update(Timer.timers)
|
||||||
|
return info
|
||||||
|
@ -28,6 +28,7 @@ class BasePolicy:
|
|||||||
self.connect_to_rviz()
|
self.connect_to_rviz()
|
||||||
|
|
||||||
self.rate = 5
|
self.rate = 5
|
||||||
|
self.info = {}
|
||||||
|
|
||||||
def load_parameters(self):
|
def load_parameters(self):
|
||||||
self.task_frame = rospy.get_param("~frame_id")
|
self.task_frame = rospy.get_param("~frame_id")
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
from datetime import datetime
|
||||||
from geometry_msgs.msg import PoseStamped
|
from geometry_msgs.msg import PoseStamped
|
||||||
|
import pandas as pd
|
||||||
import rospy
|
import rospy
|
||||||
|
import time
|
||||||
|
|
||||||
import active_grasp.msg
|
import active_grasp.msg
|
||||||
from robot_utils.ros.conversions import *
|
from robot_utils.ros.conversions import *
|
||||||
@ -34,3 +37,35 @@ def to_bbox_msg(bbox):
|
|||||||
msg.min = to_point_msg(bbox.min)
|
msg.min = to_point_msg(bbox.min)
|
||||||
msg.max = to_point_msg(bbox.max)
|
msg.max = to_point_msg(bbox.max)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
timers = dict()
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc_info):
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.tic = time.perf_counter()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
elapsed_time = time.perf_counter() - self.tic
|
||||||
|
self.timers[self.name] = elapsed_time
|
||||||
|
|
||||||
|
|
||||||
|
class Logger:
|
||||||
|
def __init__(self, logdir, policy, desc):
|
||||||
|
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
|
name = "{}_policy={},{}".format(stamp, policy, desc).strip(",")
|
||||||
|
self.path = logdir / (name + ".csv")
|
||||||
|
|
||||||
|
def log_run(self, info):
|
||||||
|
df = pd.DataFrame.from_records([info])
|
||||||
|
df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False)
|
||||||
|
@ -1,13 +1,18 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
import rospy
|
import rospy
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from active_grasp.controller import GraspController
|
from active_grasp.controller import *
|
||||||
from active_grasp.policy import make, registry
|
from active_grasp.policy import make, registry
|
||||||
|
|
||||||
|
|
||||||
def create_parser():
|
def create_parser():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--policy", type=str, choices=registry.keys())
|
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||||
|
parser.add_argument("--runs", type=int, default=10)
|
||||||
|
parser.add_argument("--logdir", type=Path, default="logs")
|
||||||
|
parser.add_argument("--desc", type=str, default="")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -17,9 +22,11 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
policy = make(args.policy)
|
policy = make(args.policy)
|
||||||
controller = GraspController(policy)
|
controller = GraspController(policy)
|
||||||
|
logger = Logger(args.logdir, args.policy, args.desc)
|
||||||
|
|
||||||
while True:
|
for _ in tqdm(range(args.runs)):
|
||||||
controller.run()
|
info = controller.run()
|
||||||
|
logger.log_run(info)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user