nbv_grasping/utils/metric_util.py
2024-10-09 16:13:22 +00:00

125 lines
5.4 KiB
Python
Executable File

import numpy as np
class MetricUtil:
    """Pose-error metrics: symmetry-aware rotation error and translation error.

    The rotation helpers take 3x3 rotation matrices. ``basic_rot_diff``,
    ``axis_rot_diff`` and ``turn_rot_diff`` also accept stacks shaped
    ``(..., 3, 3)`` and return one angle per matrix.
    """

    @staticmethod
    def rotate_around(axis, angle_deg):
        """Return the 3x3 rotation matrix for ``angle_deg`` degrees about ``axis``.

        Args:
            axis: ``"x"``, ``"y"`` or ``"z"``.
            angle_deg: Rotation angle in degrees.

        Raises:
            ValueError: If ``axis`` is not one of ``"x"``, ``"y"``, ``"z"``.
        """
        angle = angle_deg * np.pi / 180
        c, s = np.cos(angle), np.sin(angle)
        if axis == "x":
            return np.array([[1, 0, 0],
                             [0, c, -s],
                             [0, s, c]])
        if axis == "y":
            return np.array([[c, 0, s],
                             [0, 1, 0],
                             [-s, 0, c]])
        if axis == "z":
            return np.array([[c, -s, 0],
                             [s, c, 0],
                             [0, 0, 1]])
        raise ValueError("Invalid axis")

    @staticmethod
    def basic_rot_diff(r0, r1):
        """Return the geodesic angle (radians) between rotations ``r0`` and ``r1``.

        Uses trace(R0 @ R1^T) = 1 + 2*cos(theta). Inputs may be single (3, 3)
        matrices or stacks (..., 3, 3); the result has the leading (batch) shape.
        """
        mat_diff = np.matmul(r0, np.swapaxes(r1, -1, -2))
        # Trace over the last two axes so each matrix in a batch is handled on
        # its own; np.trace's default (axis1=0, axis2=1) would mix matrices.
        diff = np.trace(mat_diff, axis1=-2, axis2=-1) - 1
        # Clip guards against |cos| drifting slightly past 1 from rounding.
        return np.arccos(np.clip(diff / 2.0, a_min=-1.0, a_max=1.0))

    @staticmethod
    def axis_rot_diff(r0, r1, axis):
        """Return the angle (radians) between column ``axis`` of ``r0`` and of ``r1``.

        Suited to continuous symmetry about ``axis``: right-multiplying by a
        rotation about that axis leaves the column unchanged, so only the
        column's misalignment is measured.
        """
        axis1, axis2 = r0[..., axis], r1[..., axis]
        diff = np.sum(axis1 * axis2, axis=-1)
        return np.arccos(np.clip(diff, a_min=-1.0, a_max=1.0))

    @staticmethod
    def turn_rot_diff(r0, r1, axis, turn_degrees):
        """Return min over t in ``turn_degrees`` of the angle between ``r0 @ R_axis(t)`` and ``r1``.

        Used for discrete (n-fold) symmetry about ``axis``.
        """
        diffs = [
            MetricUtil.basic_rot_diff(np.matmul(r0, MetricUtil.rotate_around(axis, deg)), r1)
            for deg in turn_degrees
        ]
        return np.min(diffs, axis=0)

    @staticmethod
    def rot_diff_rad(r0, r1, sym):
        """Return the symmetry-aware rotation error between ``r0`` and ``r1`` in radians.

        Args:
            r0, r1: Rotation matrices.
            sym: Symmetry code:
                None or 0 -> no symmetry;
                1, 2, 3   -> continuous symmetry about x, y, z;
                4, 5, 6   -> half-turn (2-fold) symmetry about x, y, z;
                7, 8, 9   -> quarter-turn (4-fold) symmetry about x, y, z.

        Raises:
            ValueError: If ``sym`` is not one of the codes above.
        """
        axis_map = {0: "x", 1: "y", 2: "z"}
        if sym is None or sym == 0:  # no symmetry
            return MetricUtil.basic_rot_diff(r0, r1)
        if sym in (1, 2, 3):  # free rotation around axis
            return MetricUtil.axis_rot_diff(r0, r1, sym - 1)
        if sym in (4, 5, 6):  # half turn
            axis_idx, turns = sym - 4, 2
        elif sym in (7, 8, 9):  # quarter turn
            axis_idx, turns = sym - 7, 4
        else:
            # Fail loudly instead of dividing by zero turns below.
            raise ValueError(f"Invalid symmetry code: {sym!r}")
        turn_degrees = np.arange(0, 360, 360 / turns)
        return MetricUtil.turn_rot_diff(r0, r1, axis_map[axis_idx], turn_degrees)

    @staticmethod
    def _summarize(errors):
        """Return a {'mean', 'median', 'item'} dict for a list of per-sample errors.

        An empty list yields NaN statistics without numpy's empty-slice warnings.
        """
        if not errors:
            return {'mean': float('nan'), 'median': float('nan'), 'item': errors}
        return {'mean': np.mean(errors), 'median': np.median(errors), 'item': errors}

    @staticmethod
    def collect_metric(pred_pose_mat, gt_pose_mat, sym):
        """Compute per-sample and aggregate rotation/translation errors for a batch.

        Args:
            pred_pose_mat: Predicted poses indexed as (N, ..., ...) matrices with
                rotation in ``[:, :3, :3]`` and translation in ``[:, :3, 3]``.
            gt_pose_mat: Ground-truth poses in the same layout.
            sym: Sequence of N symmetry codes (see ``rot_diff_rad``), or None to
                treat every sample as having no symmetry.

        Returns:
            ``{'rot_error': ..., 'trans_error': ...}``. Each value is a dict with
            'mean', 'median' and the per-sample 'item' list. Rotation errors are
            in degrees. Translation errors are the L2 norm scaled by 100
            (presumably metres -> centimetres; confirm against callers).
        """
        pred_rot_mat = pred_pose_mat[:, :3, :3]
        gt_rot_mat = gt_pose_mat[:, :3, :3]
        pred_trans = pred_pose_mat[:, :3, 3]
        gt_trans = gt_pose_mat[:, :3, 3]
        num = pred_rot_mat.shape[0]
        syms = [None] * num if sym is None else sym

        trans_error = []
        rot_error = []
        for i in range(num):
            tdiff = np.linalg.norm(pred_trans[i] - gt_trans[i], ord=2) * 100
            rdiff = MetricUtil.rot_diff_rad(pred_rot_mat[i], gt_rot_mat[i], syms[i]) / np.pi * 180.0
            trans_error.append(tdiff)
            rot_error.append(rdiff)

        return {'rot_error': MetricUtil._summarize(rot_error),
                'trans_error': MetricUtil._summarize(trans_error)}
# -------------- Debug ---------------
def test_MetricUtil():
    """Print symmetry-aware rotation errors (degrees) for hand-built cases.

    Each case compares the identity against a known rotation under several
    symmetry codes; the output is meant for manual inspection.
    """
    identity = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    cases = [
        ("test case 0: no rotation",
         lambda: np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
         (0,)),
        ("test case 1: 29 degree rotation around x-axis",
         lambda: MetricUtil.rotate_around("x", 29),
         (0, 1, 8)),
        ("test case 2: 90 degree rotation around y-axis",
         lambda: MetricUtil.rotate_around("y", 90),
         (0, 2, 8)),
        ("test case 3: 60 degree rotation around y-axis",
         lambda: MetricUtil.rotate_around("y", 60),
         (0, 2, 8)),
        ("test case 4: 78 degree rotation around z-axis and 60 degree rotation around x-axis",
         lambda: MetricUtil.rotate_around("z", 78) @ MetricUtil.rotate_around("x", 60),
         (0, 2, 8)),
    ]
    for title, build_rotation, sym_codes in cases:
        print(title)
        target = build_rotation()
        for code in sym_codes:
            print(MetricUtil.rot_diff_rad(identity, target, code) * 180 / np.pi)
if __name__ == "__main__":
    # Run the manual sanity checks only when executed as a script, never on import.
    test_MetricUtil()