import os
import time

import numpy as np
import requests

from PytorchBoot.runners import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log

from utils.pose_util import PoseUtil
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.view_util import ViewUtil


@stereotype.runner("inference_runner")
class InferenceRunner(Runner):
    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("inference")
        self.inference_config = ConfigManager.get("runner", "inference")
        self.server_url = self.inference_config["server_url"]
        self.max_iter = self.inference_config["max_iter"]
        self.max_fail = self.inference_config["max_fail"]
        self.max_no_new_pts = self.inference_config["max_no_new_pts"]
        self.min_delta_pts_num = self.inference_config["min_delta_pts_num"]

    def check_stop(self, cnt, fail, no_new_pts):
        # Stop when any budget is exhausted: total views, failed captures,
        # or views that added too few new points.
        if cnt > self.max_iter:
            return True
        if fail > self.max_fail:
            return True
        if no_new_pts > self.max_no_new_pts:
            return True
        return False

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0.005):
        # Points below the z threshold belong to the scanning platform,
        # points above it belong to the object.
        scan_pts = world_pts[world_pts[:, 2] < z_threshold]
        obj_pts = world_pts[world_pts[:, 2] >= z_threshold]
        return scan_pts, obj_pts

    def run(self):
        ControlUtil.connect_robot()
        ControlUtil.init()

        scanned_pts_list = []
        scanned_n_to_world_pose = []
        cnt = 0
        fail = 0
        no_new_pts = 0

        # ----- first view: capture and register the initial point cloud -----
        view_data = CommunicateUtil.get_view_data(init=True)
        first_cam_to_real_world = ControlUtil.get_pose()
        if view_data is None:
            Log.error("No view data received")
            fail += 1
            return

        cam_shot_pts = ViewUtil.get_pts(view_data)
        world_shot_pts = PtsUtil.transform_point_cloud(
            cam_shot_pts, first_cam_to_real_world
        )
        _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(world_shot_pts)

        curr_pts = world_splitted_shot_pts
        curr_pose = first_cam_to_real_world
        curr_pose_6d = PoseUtil.matrix_to_rotation_6d_numpy(curr_pose[:3, :3])
        curr_pose_9d = np.concatenate([curr_pose_6d, curr_pose[:3, 3]], axis=0)
        scanned_pts_list.append(curr_pts.tolist())
        scanned_n_to_world_pose.append(curr_pose_9d.tolist())

        combined_pts = np.concatenate(scanned_pts_list, axis=0)
        downsampled_combined_pts = PtsUtil.voxel_downsample_point_cloud(combined_pts, 0.003)
        last_downsampled_combined_pts_num = downsampled_combined_pts.shape[0]
        Log.info(f"First downsampled combined pts: {last_downsampled_combined_pts_num}")

        # ----- next-best-view loop: query the server, move, capture, merge -----
        while not self.check_stop(cnt, fail, no_new_pts):
            data = {
                "scanned_pts": scanned_pts_list,
                "scanned_n_to_world_pose_9d": scanned_n_to_world_pose,
            }
            response = requests.post(self.server_url, json=data)
            result = response.json()

            # The server returns the predicted next camera pose as a 9D vector
            # (6D rotation + 3D translation); convert it back to a 4x4 matrix.
            pred_pose_9d = np.array(result["pred_pose_9d"])
            pred_rot_6d = pred_pose_9d[0, :6]
            pred_trans = pred_pose_9d[0, 6:]
            pred_rot_mat = PoseUtil.rotation_6d_to_matrix_numpy(pred_rot_6d)
            pred_pose = np.eye(4)
            pred_pose[:3, :3] = pred_rot_mat
            pred_pose[:3, 3] = pred_trans

            target_camera_pose = pred_pose @ ControlUtil.CAMERA_CORRECTION
            ControlUtil.move_to(target_camera_pose)
            cnt += 1

            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                fail += 1
                continue

            cam_shot_pts = ViewUtil.get_pts(view_data)
            left_cam_to_first_left_cam = ViewUtil.get_camera_pose(view_data)
            curr_pose = (
                first_cam_to_real_world
                @ left_cam_to_first_left_cam
                @ np.linalg.inv(ControlUtil.CAMERA_CORRECTION)
            )
            Log.info(f"pred_pose: {pred_pose}")
            Log.info(f"curr_pose: {curr_pose}")

            world_shot_pts = PtsUtil.transform_point_cloud(
                cam_shot_pts, first_cam_to_real_world
            )
            _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(world_shot_pts)
            curr_pts = world_splitted_shot_pts

            curr_pose_6d = PoseUtil.matrix_to_rotation_6d_numpy(curr_pose[:3, :3])
            curr_pose_9d = np.concatenate([curr_pose_6d, curr_pose[:3, 3]], axis=0)
            scanned_pts_list.append(curr_pts.tolist())
            scanned_n_to_world_pose.append(curr_pose_9d.tolist())

            combined_pts = np.concatenate(scanned_pts_list, axis=0)
            downsampled_combined_pts = PtsUtil.voxel_downsample_point_cloud(combined_pts, 0.003)
            curr_downsampled_combined_pts_num = downsampled_combined_pts.shape[0]
            Log.info(f"Downsampled combined pts: {curr_downsampled_combined_pts_num}")

            if curr_downsampled_combined_pts_num < last_downsampled_combined_pts_num + self.min_delta_pts_num:
                no_new_pts += 1
                Log.info(f"No new points, cnt: {cnt}, fail: {fail}, no_new_pts: {no_new_pts}")
                continue
            # Enough new points were gained: update the reference count for the next check.
            last_downsampled_combined_pts_num = curr_downsampled_combined_pts_num

        Log.success("Inference finished")
        # self.save_inference_result(scanned_pts_list, downsampled_combined_pts)

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)
        self.inference_result_dir = os.path.join(self.experiment_path, "inference_result")
        os.makedirs(self.inference_result_dir)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        self.inference_result_dir = os.path.join(self.experiment_path, "inference_result")

    def save_inference_result(self, scanned_pts_list, downsampled_combined_pts):
        dir_name = time.strftime("inference_result_%Y_%m_%d_%Hh%Mm%Ss", time.localtime())
        dir_path = os.path.join(self.inference_result_dir, dir_name)
        # Create the per-run result directory before writing into it.
        os.makedirs(dir_path)
        for i in range(len(scanned_pts_list)):
            np.savetxt(os.path.join(dir_path, f"{i}.txt"), np.array(scanned_pts_list[i]))
        np.savetxt(os.path.join(dir_path, "downsampled_combined_pts.txt"), np.array(downsampled_combined_pts))
        Log.success("Inference result saved")
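

# Usage sketch (illustrative only): how this runner could be launched directly.
# The config path below is a hypothetical placeholder, and the direct-launch
# pattern is an assumption based on the __init__ signature above, not a
# documented PytorchBoot entry point.
#
# if __name__ == "__main__":
#     runner = InferenceRunner("configs/inference_config.yaml")
#     runner.run()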