import os
import time
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data
from utils.data_load import DataLoadUtil
from utils.view_util import ViewUtil


class PointCloud:
    def __init__(self, points, camera, name):
        self.points = points
        self.camera = camera
        self.name = name


class PointCloudGroup:
    def __init__(self, point_clouds, name):
        self.point_clouds = point_clouds
        self.name = name


@stereotype.runner("temp")
class CADCloseLoopOnlineRegStrategyRunner(Runner):

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
        self.output_dir = self.generate_config["output_dir"]
        self.model_dir = self.generate_config["model_dir"]
        self.object_name = self.generate_config["object_name"]
        self.blender_bin_path = self.generate_config["blender_bin_path"]
        self.generator_script_path = self.generate_config["generator_script_path"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_shot_view_num = self.reconstruct_config["max_shot_view_num"]
        self.min_shot_new_pts_num = self.reconstruct_config["min_shot_new_pts_num"]
        self.min_coverage_increase = self.reconstruct_config["min_coverage_increase"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0):
        # points below the z threshold belong to the scan (table) plane,
        # points above it belong to the object
        scan_pts = world_pts[world_pts[:, 2] < z_threshold]
        obj_pts = world_pts[world_pts[:, 2] >= z_threshold]
        return scan_pts, obj_pts

    def loop_scan(self, first_cam_to_real_world):
        # rotate the display table in 90-degree steps and collect the object
        # points of every shot, expressed in the real-world frame
        view_pts_list = []
        first_view_data = CommunicateUtil.get_view_data(init=True)
        ControlUtil.absolute_rotate_display_table(90)
        first_pts = ViewUtil.get_pts(first_view_data)
        first_real_world_pts = PtsUtil.transform_point_cloud(
            first_pts, first_cam_to_real_world
        )
        _, first_splitted_real_world_pts = self.split_scan_pts_and_obj_pts(
            first_real_world_pts
        )
        view_pts_list.append(first_splitted_real_world_pts)

        shot_num = 4
        for i in range(shot_num - 1):
            view_data = CommunicateUtil.get_view_data()
            if i != shot_num - 2:
                # no table rotation needed after the final capture
                ControlUtil.absolute_rotate_display_table(90)
                time.sleep(0.5)
            if view_data is None:
                Log.error("No view data received")
                continue
            view_pts = ViewUtil.get_pts(view_data)
            real_world_pts = PtsUtil.transform_point_cloud(
                view_pts, first_cam_to_real_world
            )
            _, splitted_real_world_pts = self.split_scan_pts_and_obj_pts(
                real_world_pts
            )
            view_pts_list.append(splitted_real_world_pts)
        return view_pts_list
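    # Illustrative sketch only (mirrors how register() and run_one_model()
    # consume loop_scan(); not an additional code path): the per-shot object
    # clouds are merged and optionally voxel-downsampled before registration.
    #
    #   first_cam_to_real_world = ControlUtil.get_pose()
    #   view_pts_list = self.loop_scan(first_cam_to_real_world)
    #   pts = np.vstack(view_pts_list)  # merge the four shots into one cloud
    #   pts = PtsUtil.voxel_downsample_point_cloud(pts, self.voxel_size)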
    def register(self):
        ControlUtil.connect_robot()
        """ init robot """
        Log.info("start init")
        ControlUtil.init()
        first_cam_to_real_world = ControlUtil.get_pose()
        """ loop shooting """
        Log.info("start loop shooting")
        view_pts_list = self.loop_scan(first_cam_to_real_world)
        """ register """
        Log.info("start register")
        pts = np.vstack(view_pts_list)
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        if not os.path.exists(os.path.join(self.output_dir, self.object_name)):
            os.makedirs(os.path.join(self.output_dir, self.object_name))
        scene_dir = os.path.join(self.output_dir, self.object_name)
        model_path = os.path.join(self.model_dir, self.object_name, "mesh.obj")
        cad_model = trimesh.load(model_path)
        real_world_to_cad = PtsUtil.register(pts, cad_model)
        cad_to_real_world = np.linalg.inv(real_world_to_cad)
        Log.success("finish init and register")
        # offset between the real-world frame and the Blender scene frame
        # (currently unused in this method)
        real_world_to_blender_world = np.eye(4)
        real_world_to_blender_world[:3, 3] = np.asarray([0, 0, 0.9215])
        cad_model_real_world: trimesh.Trimesh = cad_model.apply_transform(
            cad_to_real_world
        )
        cad_model_real_world.export(os.path.join(scene_dir, "mesh.obj"))
        # downsampled_pts = PtsUtil.voxel_downsample_point_cloud(pts, self.voxel_size)
        np.savetxt(os.path.join(scene_dir, "pts_for_init_reg.txt"), pts)
        return cad_to_real_world

    def render_data(self):
        scene_dir = os.path.join(self.output_dir, self.object_name)
        result = subprocess.run(
            [
                self.blender_bin_path,
                "-b",
                "-P",
                self.generator_script_path,
                "--",
                scene_dir,
            ],
            capture_output=True,
            text=True,
        )
        print(result)

    def preprocess_data(self):
        save_scene_data(self.output_dir, self.object_name, file_type="npy")

    def get_scan_points_indices(
        self,
        scan_points,
        mask,
        object_mask_label=(0, 255, 0, 255),
        cam_intrinsic=None,
        cam_extrinsic=None,
    ):
        # project the display-table scan points into the image and keep those
        # that fall inside the image and are not covered by the object mask label
        scan_points_homogeneous = np.hstack(
            (scan_points, np.ones((scan_points.shape[0], 1)))
        )
        points_camera = np.dot(
            np.linalg.inv(cam_extrinsic), scan_points_homogeneous.T
        ).T[:, :3]
        points_image_homogeneous = np.dot(cam_intrinsic, points_camera.T).T
        points_image_homogeneous /= points_image_homogeneous[:, 2:]
        pixel_x = points_image_homogeneous[:, 0].astype(int)
        pixel_y = points_image_homogeneous[:, 1].astype(int)
        h, w = mask.shape[:2]
        valid_indices = (pixel_x >= 0) & (pixel_x < w) & (pixel_y >= 0) & (pixel_y < h)
        mask_colors = mask[pixel_y[valid_indices], pixel_x[valid_indices]]
        selected_points_indices = np.where(
            (mask_colors != object_mask_label).any(axis=-1)
        )[0]
        # map back from the valid-pixel subset to indices into scan_points
        selected_points_indices = np.where(valid_indices)[0][selected_points_indices]
        return selected_points_indices
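    # Hedged usage sketch for get_scan_points_indices(): the mask and camera
    # variables below are placeholders standing in for the per-frame values
    # that run_one_model() loads.
    #
    #   display_table_scan_points = np.loadtxt(scan_points_path)  # (N, 3)
    #   visible_idx = self.get_scan_points_indices(
    #       display_table_scan_points, mask_L,
    #       cam_intrinsic=cam_info["cam_intrinsic"],
    #       cam_extrinsic=curr_cam_L_to_world,
    #   )
    #   # -> indices of scan points that project inside the image and are not
    #   #    occluded by the object mask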
    def run_one_model(self, model_name):
        scene_dir = os.path.join(self.output_dir, model_name)
        ControlUtil.connect_robot()
        """ init robot """
        Log.info("start init")
        ControlUtil.init()
        first_cam_to_real_world = ControlUtil.get_pose()
        """ loop shooting """
        Log.info("start loop shooting")
        view_pts_list = self.loop_scan(first_cam_to_real_world)
        """ register """
        cad_path = os.path.join(scene_dir, "mesh.obj")
        cad_model = trimesh.load(cad_path)
        Log.info("start register")
        init_pts = np.vstack(view_pts_list)
        real_world_to_cad = PtsUtil.register(init_pts, cad_model)
        curr_cad_to_real_world = np.linalg.inv(real_world_to_cad)
        # np.savetxt(os.path.join("/home/yan20/nbv_rec/project/franka_control/debug", "pts_for_init_reg.txt"), init_pts)
        # debug_cad = cad_model.apply_transform(curr_cad_to_real_world)
        # debug_cad.export(os.path.join("/home/yan20/nbv_rec/project/franka_control/debug", "cad_for_init_reg.obj"))

        # load the pre-rendered per-frame point clouds and voxel-downsample them
        pts_dir = os.path.join(scene_dir, "pts")
        sample_view_pts_list = []
        frame_num = len(os.listdir(pts_dir))
        for frame_idx in range(frame_num):
            pts_path = os.path.join(scene_dir, "pts", f"{frame_idx}.npy")
            point_cloud = np.load(pts_path)
            if point_cloud.shape[0] != 0:
                sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(
                    point_cloud, self.voxel_size
                )
                sample_view_pts_list.append(sampled_point_cloud)
            else:
                sample_view_pts_list.append(np.zeros((0, 3)))

        """ close-loop online registration strategy """
        scanned_pts = PtsUtil.voxel_downsample_point_cloud(
            init_pts, voxel_size=self.voxel_size
        )
        shot_pts_list = []
        last_coverage = 0
        Log.info("start close-loop control")
        cnt = 0
        mask_list = []
        cam_to_cad_list = []
        cam_R_to_cad_list = []
        shot_view_idx_list = []
        scan_points_path = os.path.join(
            self.output_dir, self.object_name, "scan_points.txt"
        )
        display_table_scan_points = np.loadtxt(scan_points_path)
        for i in range(frame_num):
            path = DataLoadUtil.get_path(self.output_dir, self.object_name, i)
            mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
            mask_list.append((mask_L, mask_R))
            cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
            cam_to_cad = cam_info["cam_to_world"]
            cam_to_cad_list.append(cam_to_cad)
            cam_R_to_cad = cam_info["cam_to_world_R"]
            cam_R_to_cad_list.append(cam_R_to_cad)
        selected_view = []
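        # Close-loop outline (descriptive summary of the loop below):
        #   1. For every rendered frame, project the display-table scan points
        #      through the current cad-to-world estimate and intersect the
        #      left/right visibility results.
        #   2. Ask ReconstructionUtil for the next best view given the points
        #      scanned so far, the overlap constraint and the scan-point history.
        #   3. Move the camera there, shoot, split off the object points and
        #      merge them into the scanned cloud (voxel-downsampled).
        #   4. Stop when the coverage gain drops below the threshold or the shot
        #      budget is reached (the new-point count is logged as well).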
        while True:
            # import ipdb; ipdb.set_trace()
            history_indices = []
            scan_points_idx_list = []
            for i in range(frame_num):
                cam_to_cad = cam_to_cad_list[i]
                cam_R_to_cad = cam_R_to_cad_list[i]
                curr_cam_L_to_world = curr_cad_to_real_world @ cam_to_cad
                curr_cam_R_to_world = curr_cad_to_real_world @ cam_R_to_cad
                # cam_info is the last frame's info; the intrinsics are assumed
                # to be identical for all frames
                scan_points_indices_L = self.get_scan_points_indices(
                    display_table_scan_points,
                    mask_list[i][0],
                    cam_intrinsic=cam_info["cam_intrinsic"],
                    cam_extrinsic=curr_cam_L_to_world,
                )
                scan_points_indices_R = self.get_scan_points_indices(
                    display_table_scan_points,
                    mask_list[i][1],
                    cam_intrinsic=cam_info["cam_intrinsic"],
                    cam_extrinsic=curr_cam_R_to_world,
                )
                scan_points_indices = np.intersect1d(
                    scan_points_indices_L, scan_points_indices_R
                )
                scan_points_idx_list.append(scan_points_indices)
            for shot_view_idx in shot_view_idx_list:
                history_indices.append(scan_points_idx_list[shot_view_idx])

            cad_scanned_pts = PtsUtil.transform_point_cloud(
                scanned_pts, np.linalg.inv(curr_cad_to_real_world)
            )
            next_best_view, next_best_coverage, next_best_covered_num = (
                ReconstructionUtil.compute_next_best_view_with_overlap(
                    cad_scanned_pts,
                    sample_view_pts_list,
                    selected_view,
                    history_indices,
                    scan_points_idx_list,
                    threshold=self.voxel_size,
                    overlap_area_threshold=25,
                    scan_points_threshold=self.scan_points_threshold,
                )
            )
            if next_best_view is None:
                Log.warning("No next best view found")
                break
            selected_view.append(next_best_view)
            nbv_path = DataLoadUtil.get_path(
                self.output_dir, self.object_name, next_best_view
            )
            nbv_cam_info = DataLoadUtil.load_cam_info(nbv_path, binocular=True)
            nbv_cam_to_cad = nbv_cam_info["cam_to_world_O"]
            nbv_cam_to_world = curr_cad_to_real_world @ nbv_cam_to_cad
            target_camera_pose = nbv_cam_to_world @ ControlUtil.CAMERA_CORRECTION
            ControlUtil.move_to(target_camera_pose)

            ''' get world pts '''
            time.sleep(0.5)
            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                continue
            shot_view_idx_list.append(next_best_view)
            cam_shot_pts = ViewUtil.get_pts(view_data)
            world_shot_pts = PtsUtil.transform_point_cloud(
                cam_shot_pts, first_cam_to_real_world
            )
            _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(
                world_shot_pts
            )
            shot_pts_list.append(world_splitted_shot_pts)

            debug_dir = os.path.join(scene_dir, "debug")
            if not os.path.exists(debug_dir):
                os.makedirs(debug_dir)

            last_scanned_pts_num = scanned_pts.shape[0]
            # import ipdb; ipdb.set_trace()
            new_scanned_pts = PtsUtil.voxel_downsample_point_cloud(
                np.vstack([scanned_pts, world_splitted_shot_pts]), self.voxel_size
            )
            # last_real_world_to_cad = real_world_to_cad
            # real_world_to_cad = PtsUtil.register_fine(new_scanned_pts, cad_model)
            # # rot distance of two rotation matrices
            # rot_dist = np.arccos(
            #     (np.trace(real_world_to_cad[:3, :3].T @ last_real_world_to_cad[:3, :3]) - 1) / 2
            # )
            # print(f"-----rot dist: {rot_dist}")
            curr_cad_to_real_world = np.linalg.inv(real_world_to_cad)
            cad_splitted_shot_pts = PtsUtil.transform_point_cloud(
                world_splitted_shot_pts, real_world_to_cad
            )
            np.savetxt(
                os.path.join(debug_dir, f"shot_pts_{cnt}.txt"), world_splitted_shot_pts
            )
            np.savetxt(
                os.path.join(debug_dir, f"render_pts_{cnt}.txt"),
                sample_view_pts_list[next_best_view],
            )
            np.savetxt(
                os.path.join(debug_dir, f"reg_scanned_pts_{cnt}.txt"), new_scanned_pts
            )
            cad_pts = cad_model.vertices
            world_cad_pts = PtsUtil.transform_point_cloud(
                cad_pts, curr_cad_to_real_world
            )
            np.savetxt(os.path.join(debug_dir, f"world_cad_pts_{cnt}.txt"), world_cad_pts)
            # import ipdb; ipdb.set_trace()

            new_scanned_pts_num = new_scanned_pts.shape[0]
            scanned_pts = new_scanned_pts
            Log.info(
                f"Next Best cover pts: {next_best_covered_num}, Best coverage: {next_best_coverage}"
            )
            coverage_rate_increase = next_best_coverage - last_coverage
            if coverage_rate_increase < self.min_coverage_increase:
                Log.info(
                    f"Coverage increase = {coverage_rate_increase} < {self.min_coverage_increase}, stop scanning"
                )
                break
            last_coverage = next_best_coverage
            new_added_pts_num = new_scanned_pts_num - last_scanned_pts_num
            if new_added_pts_num < self.min_shot_new_pts_num:
                Log.info(
                    f"New added pts num = {new_added_pts_num} < {self.min_shot_new_pts_num}"
                )
                # ipdb.set_trace()
            if len(shot_pts_list) >= self.max_shot_view_num:
                Log.info(
                    f"Scanned view num = {len(shot_pts_list)} >= {self.max_shot_view_num}, stop scanning"
                )
                break
            cnt += 1
        Log.success("[Part 4/4] finish close-loop control")

    def run(self):
        self.run_one_model(self.object_name)


# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    model_path = r"/home/yan20/nbv_rec/data/models/workpiece_1/mesh.obj"
    model = trimesh.load(model_path)
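    # Hedged sketch: aligning a previously saved scan to the loaded CAD mesh
    # with the same registration call used above. "pts_for_init_reg.txt" is the
    # file written by register(); the path here is a placeholder.
    #
    #   pts = np.loadtxt("pts_for_init_reg.txt")
    #   real_world_to_cad = PtsUtil.register(pts, model)
    #   cad_to_real_world = np.linalg.inv(real_world_to_cad)
    #   print(cad_to_real_world)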