import os
import pickle
import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch
import numpy as np
import trimesh
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

from utils.pose_util import PoseUtil
from utils.view_util import ViewUtil
from configs.config import ConfigManager
import annotations.stereotype as stereotype


class OmniDataConverter:
    """Converts raw Omniverse scene captures into the format expected by the pose evaluator."""

    def __init__(self) -> None:
        raise Exception("Utility class can NOT be instantiated")

    @staticmethod
    def convert_rgb(old_rgb):
        return old_rgb

    @staticmethod
    def convert_depth(old_depth):
        return old_depth

    @staticmethod
    def convert_mask(old_mask, old_mask_label, object_name):
        # Resolve the segmentation id of the target object from the label map,
        # then binarize the mask to that id.
        target_mask_id = None
        for key, value in old_mask_label.items():
            if value == object_name:
                target_mask_id = int(key)
                break
        if target_mask_id is None:
            raise Exception("Object name not found in the mask labels")
        target_mask = (old_mask == target_mask_id)
        return target_mask

    @staticmethod
    def convert_mesh(mesh):
        # Object models are stored in millimeters; rescale to meters.
        object_model_scale = [0.001, 0.001, 0.001]
        mesh.apply_scale(object_model_scale)
        return mesh

    @staticmethod
    def convert_gt_pose(scene_data, object_name, cam_pose):
        # Build the object-to-world pose, then express it in the camera frame.
        # The stored quaternion is assumed to follow scipy's (x, y, z, w) order.
        pos = scene_data[object_name]["position"]
        quat = scene_data[object_name]["rotation"]
        rot = R.from_quat(quat).as_matrix()
        obj_pose = np.eye(4)
        obj_pose[:3, :3] = rot
        obj_pose[:3, 3] = pos
        obj_cam_pose = np.linalg.inv(cam_pose.cpu().numpy()) @ obj_pose
        return np.asarray(obj_cam_pose)


def convert_data(object_name, scene_data, cam_pose, rgb, depth, seg, seg_labels, camera_params):
    rgb = OmniDataConverter.convert_rgb(rgb)
    depth = OmniDataConverter.convert_depth(depth)
    mask = OmniDataConverter.convert_mask(seg, seg_labels, object_name)
    K = np.array([
        [camera_params["fx"], 0, camera_params["cx"]],
        [0, camera_params["fy"], camera_params["cy"]],
        [0, 0, 1],
    ])
    gt_pose = OmniDataConverter.convert_gt_pose(scene_data, object_name, cam_pose)
    return K, rgb, depth, mask, gt_pose, object_name


def get_mesh(obj_name, source):
    data_dir = ConfigManager.get("datasets", "general", "data_dir")
    class_name = obj_name[:-4]
    mesh_path = os.path.join(data_dir, source, "objects", class_name, obj_name, "Scan", "Simp.obj")  # TODO: to be changed
    mesh = trimesh.load(mesh_path)
    mesh = OmniDataConverter.convert_mesh(mesh)
    return mesh


def get_scene_data(scene_name, source, data_type):
    data_dir = ConfigManager.get("datasets", "general", "data_dir")
    scene_data_path = os.path.join(data_dir, source, data_type, scene_name, "scene.pickle")
    with open(scene_data_path, "rb") as f:
        scene_data = pickle.load(f)
    return scene_data


def get_transformed_mat(src_mat, delta_rot, target_center_w):
    # Apply the delta rotation to the source camera orientation, then re-position
    # the camera so that it keeps its original distance to the target center
    # while looking at it along its +Z axis.
    src_rot = src_mat[:3, :3]
    dst_rot = src_rot @ delta_rot.T
    dst_mat = torch.eye(4).to(dst_rot.device)
    dst_mat[:3, :3] = dst_rot
    distance = torch.norm(target_center_w - src_mat[:3, 3])
    z_axis_camera = dst_rot[:3, 2].reshape(-1)
    new_camera_position_w = target_center_w - distance * z_axis_camera
    dst_mat[:3, 3] = new_camera_position_w
    return dst_mat
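# A minimal sanity-check sketch for get_transformed_mat (illustrative only, not
# part of the pipeline): with an identity source orientation and an identity
# delta rotation, the camera should land on the -Z side of the target at the
# same distance it started from. The helper name and tolerance are assumptions.
def _example_check_transformed_mat():
    src_mat = torch.eye(4)
    src_mat[:3, 3] = torch.tensor([0.0, 0.0, 1.0])
    target_center_w = torch.zeros(3)
    dst_mat = get_transformed_mat(src_mat, torch.eye(3), target_center_w)
    src_dist = torch.norm(target_center_w - src_mat[:3, 3])
    dst_dist = torch.norm(target_center_w - dst_mat[:3, 3])
    # Distance to the target center is preserved by construction.
    assert torch.allclose(src_dist, dst_dist, atol=1e-6)
    # The new +Z axis points from the camera towards the target center.
    z_axis = dst_mat[:3, 2]
    to_target = target_center_w - dst_mat[:3, 3]
    assert torch.allclose(to_target / torch.norm(to_target), z_axis, atol=1e-6)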
def get_score_from_data_list(data_list, source):
    # Group frames by object so that each object's views can be scored in one
    # batched request to the pose estimation service.
    avg_adds = 0
    data_dict = {}
    for K, rgb, depth, mask, gt_pose, object_name in data_list:
        if object_name not in data_dict:
            mesh = get_mesh(object_name, source)
            data_dict[object_name] = {"K": K, "mesh": mesh, "rgb": [rgb], "depth": [depth], "mask": [mask], "gt_pose": [gt_pose]}
        else:
            data_dict[object_name]["rgb"].append(rgb)
            data_dict[object_name]["depth"].append(depth)
            data_dict[object_name]["mask"].append(mask)
            data_dict[object_name]["gt_pose"].append(gt_pose)
    for object_name in data_dict:
        K = data_dict[object_name]["K"]
        mesh = data_dict[object_name]["mesh"]
        rgb_batch = np.stack(data_dict[object_name]["rgb"])
        depth_batch = np.stack(data_dict[object_name]["depth"])
        mask_batch = np.stack(data_dict[object_name]["mask"])
        gt_pose_batch = np.stack(data_dict[object_name]["gt_pose"])
        _, results_batch = ViewUtil.get_object_pose_batch(K, mesh, rgb_batch, depth_batch, mask_batch, gt_pose_batch, 11111)  # TODO: port number should be variable
        print("object_name:", object_name, "length:", len(gt_pose_batch), len(results_batch))
        for result in results_batch:
            avg_adds += result["ADD-S"]
    # Average the ADD-S scores over all frames, across all objects.
    avg_adds /= len(data_list)
    return avg_adds


async def async_get_view(total, all_src_mat_list, all_part_gt_dst_mat_list, all_full_gt_dst_mat_list, all_est_dst_mat_list,
                         all_source_list, all_data_type_list, all_scene_name_list, all_object_name_list, web_server_port):
    # For each item, render the source view and the three candidate destination
    # views concurrently; each of the four renderers listens on its own port
    # (web_server_port .. web_server_port + 3).
    all_src_view_data_list = []
    all_part_gt_dst_view_data_list = []
    all_full_gt_dst_view_data_list = []
    all_est_dst_view_data_list = []
    with ThreadPoolExecutor() as executor:
        loop = asyncio.get_event_loop()
        for i in tqdm(range(total), desc="----Processing items", ncols=100):
            src_mat = all_src_mat_list[i]
            part_gt_dst_mat = all_part_gt_dst_mat_list[i]
            full_gt_dst_mat = all_full_gt_dst_mat_list[i]
            est_dst_mat = all_est_dst_mat_list[i]
            source = all_source_list[i]
            data_type = all_data_type_list[i]
            scene_name = all_scene_name_list[i]
            obj_name = all_object_name_list[i]
            src_view_future = loop.run_in_executor(executor, ViewUtil.get_view, src_mat, source, data_type, scene_name, web_server_port)
            part_gt_dst_view_future = loop.run_in_executor(executor, ViewUtil.get_view, part_gt_dst_mat, source, data_type, scene_name, web_server_port + 1)
            full_gt_dst_view_future = loop.run_in_executor(executor, ViewUtil.get_view, full_gt_dst_mat, source, data_type, scene_name, web_server_port + 2)
            est_dst_view_future = loop.run_in_executor(executor, ViewUtil.get_view, est_dst_mat, source, data_type, scene_name, web_server_port + 3)
            src_view_data, part_gt_dst_view_data, full_gt_dst_view_data, est_dst_view_data = await asyncio.gather(
                src_view_future, part_gt_dst_view_future, full_gt_dst_view_future, est_dst_view_future
            )
            scene_data = get_scene_data(scene_name, source, data_type)
            all_src_view_data_list.append(convert_data(obj_name, scene_data, src_mat, *src_view_data))
            all_part_gt_dst_view_data_list.append(convert_data(obj_name, scene_data, part_gt_dst_mat, *part_gt_dst_view_data))
            all_full_gt_dst_view_data_list.append(convert_data(obj_name, scene_data, full_gt_dst_mat, *full_gt_dst_view_data))
            all_est_dst_view_data_list.append(convert_data(obj_name, scene_data, est_dst_mat, *est_dst_view_data))
    return (all_src_view_data_list, all_part_gt_dst_view_data_list, all_full_gt_dst_view_data_list, all_est_dst_view_data_list)
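# A minimal, self-contained sketch of the fan-out pattern async_get_view relies
# on (illustrative only): run_in_executor offloads each blocking fetch to a
# worker thread, and asyncio.gather awaits them together. `fetch` is a
# hypothetical stand-in for ViewUtil.get_view.
async def _example_fan_out(items, fetch):
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor() as executor:
        futures = [loop.run_in_executor(executor, fetch, item) for item in items]
        return await asyncio.gather(*futures)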
@stereotype.evaluation_method("object_pose_improvement")
def evaluate(output_list, data_list):
    web_server_port = ConfigManager.get("settings", "experiment", "web_api", "port")
    all_src_mat_list = []
    all_part_gt_dst_mat_list = []
    all_full_gt_dst_mat_list = []
    all_est_dst_mat_list = []
    all_scene_name_list = []
    all_object_name_list = []
    all_source_list = []
    all_data_type_list = []
    all_target_center_w_list = []
    for output, data in zip(output_list, data_list):
        gt_delta_rot_6d_list = data["delta_rot_6d"]
        est_delta_rot_6d_list = output["estimated_delta_rot_6d"]
        src_mat_list = data["src_transform"]
        gt_mat_list = data["dst_transform"]
        scene_name_list = data["scene_name"]
        object_name_list = data["target_name"]
        target_pts_list = data["target_pts"]
        source_list = data["source"]
        data_type_list = data["data_type"]
        # Lift the target point cloud centers from the camera frame to world.
        target_center_c_list = torch.mean(target_pts_list, dim=1)
        target_center_w_list = torch.bmm(src_mat_list[:, :3, :3], target_center_c_list.unsqueeze(2)).squeeze(2) + src_mat_list[:, :3, 3]
        gt_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_delta_rot_6d_list)
        est_delta_rot_mat_list = PoseUtil.rotation_6d_to_matrix_tensor_batch(est_delta_rot_6d_list)
        for i in range(len(scene_name_list)):
            src_mat = src_mat_list[i]
            target_center_w = target_center_w_list[i]
            gt_delta_rot_mat = gt_delta_rot_mat_list[i]
            est_delta_rot_mat = est_delta_rot_mat_list[i]
            part_gt_dst_mat = get_transformed_mat(src_mat, gt_delta_rot_mat, target_center_w)
            est_dst_mat = get_transformed_mat(src_mat, est_delta_rot_mat, target_center_w)
            all_src_mat_list.append(src_mat)
            all_part_gt_dst_mat_list.append(part_gt_dst_mat)
            all_full_gt_dst_mat_list.append(gt_mat_list[i])
            all_est_dst_mat_list.append(est_dst_mat)
            all_scene_name_list.append(scene_name_list[i])
            all_object_name_list.append(object_name_list[i])
            all_source_list.append(source_list[i])
            all_data_type_list.append(data_type_list[i])
            all_target_center_w_list.append(target_center_w)
    source = all_source_list[0]  # all items are assumed to share one data source
    total = len(all_src_mat_list)
    loop = asyncio.get_event_loop()
    all_view_data_list = loop.run_until_complete(async_get_view(
        total, all_src_mat_list, all_part_gt_dst_mat_list, all_full_gt_dst_mat_list, all_est_dst_mat_list,
        all_source_list, all_data_type_list, all_scene_name_list, all_object_name_list, web_server_port
    ))
    all_src_view_data_list, all_part_gt_dst_view_data_list, all_full_gt_dst_view_data_list, all_est_dst_view_data_list = all_view_data_list
    src_score = get_score_from_data_list(all_src_view_data_list, source)
    part_gt_dst_score = get_score_from_data_list(all_part_gt_dst_view_data_list, source)
    full_gt_dst_score = get_score_from_data_list(all_full_gt_dst_view_data_list, source)
    est_dst_score = get_score_from_data_list(all_est_dst_view_data_list, source)
    score_improvement = est_dst_score - src_score
    score_diff_to_full_gt = full_gt_dst_score - est_dst_score
    score_diff_to_part_gt = part_gt_dst_score - est_dst_score
    look_at_center_score_diff = full_gt_dst_score - part_gt_dst_score
    results = {
        "scalars": {
            "object_pose_score": {
                "src": src_score,
                "part_gt_dst": part_gt_dst_score,
                "full_gt_dst": full_gt_dst_score,
                "est_dst": est_dst_score,
            },
            "object_pose_score_improvement": score_improvement,
            "object_pose_score_diff_to_full_gt": score_diff_to_full_gt,
            "object_pose_score_diff_to_part_gt": score_diff_to_part_gt,
            "object_pose_look_at_center_score_diff": look_at_center_score_diff,
        }
    }
    return results


if __name__ == "__main__":
    pass
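# Hypothetical usage sketch (not executed here): the evaluation pipeline is
# expected to call `evaluate` through the "object_pose_improvement" stereotype
# with batched model outputs and matching dataset items, roughly:
#
#   results = evaluate(output_list, data_list)
#   print(results["scalars"]["object_pose_score_improvement"])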