245 lines
11 KiB
Python
245 lines
11 KiB
Python
|
import re
|
||
|
import os
|
||
|
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
import pickle
|
||
|
import asyncio
|
||
|
|
||
|
from concurrent.futures import ThreadPoolExecutor
|
||
|
from scipy.spatial.transform import Rotation as R
|
||
|
import trimesh
|
||
|
|
||
|
from utils.pose_util import PoseUtil
|
||
|
import annotations.stereotype as stereotype
|
||
|
from configs.config import ConfigManager
|
||
|
from utils.view_util import ViewUtil
|
||
|
from configs.config import ConfigManager
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
class OmniDataConverter():
    """Static helpers that adapt raw view/scene data for pose evaluation.

    The class only groups ``@staticmethod`` helpers; constructing it raises.
    """

    def __init__(self) -> None:
        raise Exception("Utility class can NOT be instantiated")

    @staticmethod
    def convert_rgb(old_rgb):
        """Return the RGB image unchanged (identity conversion hook)."""
        return old_rgb

    @staticmethod
    def convert_depth(old_depth):
        """Return the depth map unchanged (identity conversion hook)."""
        return old_depth

    @staticmethod
    def convert_mask(old_mask, old_mask_label, object_name):
        """Select the pixels of ``old_mask`` that belong to ``object_name``.

        Args:
            old_mask: Id mask; compared element-wise against the object's id.
            old_mask_label: Mapping of id (anything ``int()`` accepts) to
                object name.
            object_name: Name to look up among the values of
                ``old_mask_label``.

        Returns:
            The result of ``old_mask == id`` for the first id whose label
            equals ``object_name``.

        Raises:
            Exception: If no label equals ``object_name``.
        """
        # First matching label wins, in the mapping's iteration order.
        target_mask_id = next(
            (int(key) for key, value in old_mask_label.items() if value == object_name),
            None,
        )
        if target_mask_id is None:
            raise Exception("Object name not found in the mask labels")
        return old_mask == target_mask_id

    @staticmethod
    def convert_mesh(mesh):
        """Scale ``mesh`` by 0.001 on every axis and return the same object.

        ``apply_scale`` is called on the mesh that was passed in.
        NOTE(review): the factor presumably converts millimetres to metres;
        confirm against the dataset.
        """
        object_model_scale = [0.001, 0.001, 0.001]
        mesh.apply_scale(object_model_scale)
        return mesh

    @staticmethod
    def convert_gt_pose(scene_data, object_name, cam_pose):
        """Express an object's scene pose in the camera frame.

        Args:
            scene_data: Mapping with ``scene_data[object_name]["position"]``
                and ``["rotation"]`` entries. The rotation is handed to
                ``Rotation.from_quat``.
                NOTE(review): SciPy reads quaternions as (x, y, z, w) by
                default; confirm the dataset stores them in that order.
            object_name: Key of the object inside ``scene_data``.
            cam_pose: 4x4 camera pose, either a ``torch.Tensor`` (any
                device, with or without grad) or anything ``np.asarray``
                accepts.

        Returns:
            A 4x4 ``numpy.ndarray`` equal to ``inv(cam_pose) @ object_pose``.
        """
        pos = scene_data[object_name]["position"]
        quat = scene_data[object_name]["rotation"]

        obj_pose = np.eye(4)
        obj_pose[:3, :3] = R.from_quat(quat).as_matrix()
        obj_pose[:3, 3] = pos

        # Convert explicitly instead of passing a tensor to NumPy and relying
        # on implicit tensor/ndarray interop for the inverse and the matmul.
        if isinstance(cam_pose, torch.Tensor):
            cam_pose = cam_pose.detach().cpu().numpy()
        cam_pose = np.asarray(cam_pose)

        obj_cam_pose = np.linalg.inv(cam_pose) @ obj_pose
        return np.asarray(obj_cam_pose)
|
||
|
|
||
|
def convert_data(object_name, scene_data, cam_pose, rgb, depth, seg, seg_labels, camera_params):
    """Bundle one rendered view into the tuple used by the scoring step.

    Args:
        object_name: Name of the target object; also used to pick its mask
            and its pose entry in ``scene_data``.
        scene_data: Per-scene mapping passed to
            ``OmniDataConverter.convert_gt_pose``.
        cam_pose: Camera pose of the view.
        rgb: Colour image of the view.
        depth: Depth map of the view.
        seg: Segmentation id mask of the view.
        seg_labels: Mapping from segmentation id to object name.
        camera_params: Mapping providing ``"fx"``, ``"fy"``, ``"cx"`` and
            ``"cy"``.

    Returns:
        ``(K, rgb, depth, mask, gt_pose, object_name)`` where ``K`` is the
        3x3 intrinsics matrix built from ``camera_params``.
    """
    color = OmniDataConverter.convert_rgb(rgb)
    depth_map = OmniDataConverter.convert_depth(depth)
    object_mask = OmniDataConverter.convert_mask(seg, seg_labels, object_name)

    # Pinhole intrinsics laid out row by row.
    intrinsics = np.array([
        [camera_params["fx"], 0, camera_params["cx"]],
        [0, camera_params["fy"], camera_params["cy"]],
        [0, 0, 1],
    ])

    object_in_camera = OmniDataConverter.convert_gt_pose(scene_data, object_name, cam_pose)
    return intrinsics, color, depth_map, object_mask, object_in_camera, object_name
|
||
|
|
||
|
def get_mesh(obj_name, source):
    """Load and rescale the mesh of ``obj_name`` from the dataset directory.

    The file is read from
    ``<data_dir>/<source>/objects/<class_name>/<obj_name>/Scan/Simp.obj``,
    where ``data_dir`` comes from ``ConfigManager`` and ``class_name`` is
    ``obj_name`` without its last four characters.

    Args:
        obj_name: Object identifier.
        source: Dataset source sub-directory.

    Returns:
        The loaded mesh after ``OmniDataConverter.convert_mesh``.
    """
    dataset_root = ConfigManager.get("datasets", "general", "data_dir")
    # Class folder name = object name minus its last four characters.
    category = obj_name[:-4]
    path_parts = (dataset_root, source, "objects", category, obj_name, "Scan", "Simp.obj")
    mesh_file = os.path.join(*path_parts)  # TODO: to be changed
    return OmniDataConverter.convert_mesh(trimesh.load(mesh_file))
|
||
|
|
||
|
def get_scene_data(scene_name, source, data_type):
    """Read the pickled description of one scene.

    The file is ``<data_dir>/<source>/<data_type>/<scene_name>/scene.pickle``
    with ``data_dir`` taken from ``ConfigManager``.

    Args:
        scene_name: Scene directory name.
        source: Dataset source sub-directory.
        data_type: Sub-directory between the source and the scene.

    Returns:
        Whatever object the pickle file contains.
    """
    root = ConfigManager.get("datasets", "general", "data_dir")
    pickle_path = os.path.join(root, source, data_type, scene_name, "scene.pickle")
    # NOTE: unpickling can execute arbitrary code; only load trusted dataset files.
    with open(pickle_path, "rb") as handle:
        return pickle.load(handle)
|
||
|
|
||
|
def get_transformed_mat(src_mat, delta_rot, target_center_w):
    """Rotate a camera pose and re-place it at the same distance from a target.

    The new rotation is ``src_rot @ delta_rot.T``. The new position is chosen
    so that the target lies on the camera's z axis (third column of the new
    rotation), at the same distance the source camera had from the target.

    Args:
        src_mat: 4x4 source pose tensor; its upper-left 3x3 block is the
            rotation and ``src_mat[:3, 3]`` the position.
        delta_rot: 3x3 rotation tensor.
        target_center_w: 3-vector tensor, the target centre in the same frame
            as ``src_mat``'s position.

    Returns:
        A new 4x4 pose tensor on the same device and with the same dtype as
        the rotation. The original always produced float32 and silently
        downcast higher-precision inputs.
    """
    src_rot = src_mat[:3, :3]
    dst_rot = src_rot @ delta_rot.T

    # Match dtype and device of the inputs instead of defaulting to float32.
    dst_mat = torch.eye(4, dtype=dst_rot.dtype, device=dst_rot.device)
    dst_mat[:3, :3] = dst_rot

    # Keep the camera-to-target distance of the source pose.
    distance = torch.norm(target_center_w - src_mat[:3, 3])
    z_axis_camera = dst_rot[:3, 2].reshape(-1)
    # Step back from the target along the new viewing (z) axis.
    dst_mat[:3, 3] = target_center_w - distance * z_axis_camera
    return dst_mat
|
||
|
|
||
|
def get_score_from_data_list(data_list, source, port=11111):
    """Average the ``"ADD-S"`` value returned by the pose service over views.

    Views are grouped by object name so that each object is sent to
    ``ViewUtil.get_object_pose_batch`` as one batch together with its mesh.

    Args:
        data_list: Sequence of ``(K, rgb, depth, mask, gt_pose, object_name)``
            tuples.
        source: Dataset source, forwarded to ``get_mesh``.
        port: Port forwarded to ``ViewUtil.get_object_pose_batch``. Defaults
            to the previously hard-coded ``11111``.

    Returns:
        Sum of every result's ``"ADD-S"`` divided by ``len(data_list)``.

    Raises:
        ZeroDivisionError: If ``data_list`` is empty.
    """
    grouped = {}
    for K, rgb, depth, mask, gt_pose, object_name in data_list:
        entry = grouped.get(object_name)
        if entry is None:
            # Intrinsics and mesh are taken from the first view of each object.
            entry = {
                "K": K,
                "mesh": get_mesh(object_name, source),
                "rgb": [],
                "depth": [],
                "mask": [],
                "gt_pose": [],
            }
            grouped[object_name] = entry
        entry["rgb"].append(rgb)
        entry["depth"].append(depth)
        entry["mask"].append(mask)
        entry["gt_pose"].append(gt_pose)

    avg_adds = 0
    for object_name, entry in grouped.items():
        rgb_batch = np.stack(entry["rgb"])
        depth_batch = np.stack(entry["depth"])
        mask_batch = np.stack(entry["mask"])
        gt_pose_batch = np.stack(entry["gt_pose"])

        _, results_batch = ViewUtil.get_object_pose_batch(
            entry["K"], entry["mesh"], rgb_batch, depth_batch, mask_batch, gt_pose_batch, port
        )
        print("object_name:",object_name, "length:",len(gt_pose_batch),len(results_batch))
        for result in results_batch:
            avg_adds += result["ADD-S"]

    # Normalised by the number of input views, not by the number of results.
    avg_adds /= len(data_list)
    return avg_adds
|
||
|
|
||
|
async def async_get_view(total, all_src_mat_list, all_part_gt_dst_mat_list, all_full_gt_dst_mat_list, all_est_dst_mat_list,
                         all_source_list, all_data_type_list, all_scene_name_list, all_object_name_list, web_server_port):
    """Fetch four views per sample concurrently and convert them.

    For each index ``i`` in ``range(total)`` the source, partial-GT, full-GT
    and estimated camera matrices are each passed to ``ViewUtil.get_view`` on
    a worker thread. The four requests use ports ``web_server_port``,
    ``+1``, ``+2`` and ``+3`` in that order. The views of one sample are
    awaited together before the next sample starts.

    Args:
        total: Number of samples to process.
        all_src_mat_list: Source camera matrices, indexed by sample.
        all_part_gt_dst_mat_list: Partial ground-truth destination matrices.
        all_full_gt_dst_mat_list: Full ground-truth destination matrices.
        all_est_dst_mat_list: Estimated destination matrices.
        all_source_list: Dataset source per sample.
        all_data_type_list: Data type per sample.
        all_scene_name_list: Scene name per sample.
        all_object_name_list: Target object name per sample.
        web_server_port: Base port of the view service.

    Returns:
        A 4-tuple of lists (source, partial-GT, full-GT, estimated), each
        holding one ``convert_data`` result per sample.
    """
    all_src_view_data_list = []
    all_part_gt_dst_view_data_list = []
    all_full_gt_dst_view_data_list = []
    all_est_dst_view_data_list = []
    output_lists = (
        all_src_view_data_list,
        all_part_gt_dst_view_data_list,
        all_full_gt_dst_view_data_list,
        all_est_dst_view_data_list,
    )

    # Inside a coroutine the running loop is the one to schedule on.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        for i in tqdm(range(total), desc="----Processing items", ncols=100):
            camera_mats = (
                all_src_mat_list[i],
                all_part_gt_dst_mat_list[i],
                all_full_gt_dst_mat_list[i],
                all_est_dst_mat_list[i],
            )
            source = all_source_list[i]
            data_type = all_data_type_list[i]
            scene_name = all_scene_name_list[i]
            obj_name = all_object_name_list[i]

            # One request per camera matrix, each on its own port offset (0..3).
            futures = [
                loop.run_in_executor(
                    executor, ViewUtil.get_view, mat, source, data_type, scene_name, web_server_port + offset
                )
                for offset, mat in enumerate(camera_mats)
            ]
            view_data = await asyncio.gather(*futures)

            scene_data = get_scene_data(scene_name, source, data_type)
            for out_list, mat, view in zip(output_lists, camera_mats, view_data):
                out_list.append(convert_data(obj_name, scene_data, mat, *view))

    return (all_src_view_data_list, all_part_gt_dst_view_data_list, all_full_gt_dst_view_data_list, all_est_dst_view_data_list)
|
||
|
|
||
|
@stereotype.evaluation_method("object_pose_improvement")
def evaluate(output_list, data_list):
    """Score object pose estimation from four viewpoints per sample.

    For every sample four camera poses are collected: the source pose, the
    pose obtained by applying the ground-truth delta rotation
    (``part_gt_dst``), the dataset's destination pose (``full_gt_dst``) and
    the pose obtained by applying the estimated delta rotation
    (``est_dst``). Views are rendered through ``async_get_view`` and each
    group is scored with ``get_score_from_data_list``.

    Args:
        output_list: Per-batch model outputs; each provides
            ``"estimated_delta_rot_6d"``.
        data_list: Per-batch inputs; each provides ``"delta_rot_6d"``,
            ``"src_transform"``, ``"dst_transform"``, ``"scene_name"``,
            ``"target_name"``, ``"target_pts"``, ``"source"`` and
            ``"data_type"``. ``"src_transform"`` is indexed as a batch of
            poses and ``"target_pts"`` is averaged over dimension 1.

    Returns:
        ``{"scalars": {...}}`` holding the four scores and their pairwise
        differences.

    Raises:
        IndexError: If no sample is provided, because the first sample's
            source is needed for mesh lookup.
    """
    web_server_port = ConfigManager.get("settings", "experiment", "web_api", "port")
    all_src_mat_list = []
    all_part_gt_dst_mat_list = []
    all_full_gt_dst_mat_list = []
    all_est_dst_mat_list = []
    all_scene_name_list = []
    all_object_name_list = []
    all_source_list = []
    all_data_type_list = []

    for output, data in zip(output_list, data_list):
        gt_delta_rot_6d_list = data["delta_rot_6d"]
        est_delta_rot_6d_list = output["estimated_delta_rot_6d"]
        src_mat_list = data["src_transform"]
        gt_mat_list = data["dst_transform"]
        scene_name_list = data["scene_name"]
        object_name_list = data["target_name"]
        target_pts_list = data["target_pts"]
        source_list = data["source"]
        data_type_list = data["data_type"]

        # Target centre in the camera frame, then moved to the frame of
        # src_transform: R @ c + t.
        target_center_c_list = torch.mean(target_pts_list, dim=1)
        target_center_w_list = (
            torch.bmm(src_mat_list[:, :3, :3], target_center_c_list.unsqueeze(2)).squeeze(2)
            + src_mat_list[:, :3, 3]
        )
        gt_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_delta_rot_6d_list)
        est_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(est_delta_rot_6d_list)

        for i in range(len(scene_name_list)):
            src_mat = src_mat_list[i]
            target_center_w = target_center_w_list[i]
            part_gt_dst_mat = get_transformed_mat(src_mat, gt_delta_rot_mat_list[i], target_center_w)
            est_dst_mat = get_transformed_mat(src_mat, est_delta_rot_mat_list[i], target_center_w)

            all_src_mat_list.append(src_mat)
            all_part_gt_dst_mat_list.append(part_gt_dst_mat)
            all_full_gt_dst_mat_list.append(gt_mat_list[i])
            all_est_dst_mat_list.append(est_dst_mat)
            all_scene_name_list.append(scene_name_list[i])
            all_object_name_list.append(object_name_list[i])
            all_source_list.append(source_list[i])
            all_data_type_list.append(data_type_list[i])

    # The first sample's source is used for every mesh lookup during scoring.
    source = all_source_list[0]
    total = len(all_src_mat_list)

    # Newer Python versions raise RuntimeError when no current loop exists;
    # create and register one in that case.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    all_view_data_list = loop.run_until_complete(
        async_get_view(
            total, all_src_mat_list, all_part_gt_dst_mat_list, all_full_gt_dst_mat_list, all_est_dst_mat_list,
            all_source_list, all_data_type_list, all_scene_name_list, all_object_name_list, web_server_port,
        )
    )
    (
        all_src_view_data_list,
        all_part_gt_dst_view_data_list,
        all_full_gt_dst_view_data_list,
        all_est_dst_view_data_list,
    ) = all_view_data_list

    src_score = get_score_from_data_list(all_src_view_data_list, source)
    part_gt_dst_score = get_score_from_data_list(all_part_gt_dst_view_data_list, source)
    full_gt_dst_score = get_score_from_data_list(all_full_gt_dst_view_data_list, source)
    est_dst_score = get_score_from_data_list(all_est_dst_view_data_list, source)

    score_improvement = est_dst_score - src_score
    score_diff_to_full_gt = full_gt_dst_score - est_dst_score
    score_diff_to_part_gt = part_gt_dst_score - est_dst_score
    look_at_center_score_diff = full_gt_dst_score - part_gt_dst_score

    results = {
        "scalars": {
            "object_pose_score": {
                "src": src_score,
                "part_gt_dst": part_gt_dst_score,
                "full_gt_dst": full_gt_dst_score,
                "est_dst": est_dst_score,
            },
            "object_pose_score_improvement": score_improvement,
            "object_pose_score_diff_to_full_gt": score_diff_to_full_gt,
            "object_pose_score_diff_to_part_gt": score_diff_to_part_gt,
            "object_pose_look_at_center_score_diff": look_at_center_score_diff,
        }
    }
    return results
|
||
|
|
||
|
if __name__ == "__main__":
    # Deliberately a no-op: nothing runs when this file is executed as a script.
    pass
|