diff --git a/configs/local/inference_config.yaml b/configs/local/inference_config.yaml index ca8964e..b3bc6d5 100644 --- a/configs/local/inference_config.yaml +++ b/configs/local/inference_config.yaml @@ -1,12 +1,12 @@ runner: general: - seed: 1 + seed: 0 device: cuda cuda_visible_devices: "0,1,2,3,4,5,6,7" experiment: - name: overfit_ab_global_only + name: train_ab_global_only root_dir: "experiments" epoch: -1 # -1 stands for last epoch @@ -22,7 +22,7 @@ runner: dataset: OmniObject3d_train: root_dir: "/data/hofee/data/new_full_data" - model_dir: "/data/hofee/data/scaled_object_meshes" + model_dir: "/data/hofee/data/object_meshes_part3" source: seq_reconstruction_dataset split_file: "/data/hofee/data/sample.txt" type: test @@ -35,7 +35,7 @@ dataset: OmniObject3d_test: root_dir: "/data/hofee/data/new_full_data" - model_dir: "/data/hofee/data/scaled_object_meshes" + model_dir: "/data/hofee/data/object_meshes_part3" source: seq_reconstruction_dataset split_file: "/data/hofee/data/sample.txt" type: test diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index cd9de46..18a8163 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -9,7 +9,7 @@ runner: experiment: name: train_ab_global_only root_dir: "experiments" - use_checkpoint: False + use_checkpoint: True epoch: -1 # -1 stands for last epoch max_epochs: 5000 save_checkpoint_interval: 1 diff --git a/runners/inferencer.py b/runners/inferencer.py index 121cb98..e98c496 100644 --- a/runners/inferencer.py +++ b/runners/inferencer.py @@ -77,17 +77,17 @@ class Inferencer(Runner): status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list)) def predict_sequence(self, data, cr_increase_threshold=0, max_iter=50, max_retry=5): - scene_name = data["scene_name"][0] + scene_name = data["scene_name"] Log.info(f"Processing scene: {scene_name}") status_manager.set_status("inference", "inferencer", "scene", scene_name) ''' data for rendering ''' - scene_path = data["scene_path"][0] - O_to_L_pose = data["O_to_L_pose"][0] + scene_path = data["scene_path"] + O_to_L_pose = data["O_to_L_pose"] voxel_threshold = self.voxel_size filter_degree = 75 down_sampled_model_pts = data["gt_pts"] - import ipdb; ipdb.set_trace() + first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0] first_frame_to_world = np.eye(4) first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6]) @@ -95,14 +95,13 @@ class Inferencer(Runner): ''' data for inference ''' input_data = {} - input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device) + input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0) input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)] input_data["mode"] = namespace.Mode.TEST input_pts_N = input_data["combined_scanned_pts"].shape[1] - - first_frame_target_pts, _ = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, down_sampled_model_pts, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose) + first_frame_target_pts, first_frame_target_normals = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose) scanned_view_pts = [first_frame_target_pts] - last_pred_cr = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold) + last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold) retry_duplication_pose = [] retry_no_pts_pose = [] @@ -118,7 +117,7 @@ class Inferencer(Runner): pred_pose[:3,3] = pred_pose_9d[0,6:] try: - new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True) + new_target_pts, new_target_normals = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose) except Exception as e: Log.warning(f"Error in scene {scene_path}, {e}") print("current pose: ", pred_pose) @@ -127,12 +126,18 @@ class Inferencer(Runner): retry += 1 continue + if new_target_pts.shape[0] == 0: + print("no pts in new target") + retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist()) + retry += 1 + continue - pred_cr = self.compute_coverage_rate(scanned_view_pts, new_target_pts_world, down_sampled_model_pts, threshold=voxel_threshold) - - print(pred_cr, last_pred_cr, " max: ", data["max_coverage_rate"]) - if pred_cr >= data["max_coverage_rate"]: - print("max coverage rate reached!") + pred_cr, new_added_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold) + print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"]) + if pred_cr >= data["seq_max_coverage_rate"] - 1e-3: + print("max coverage rate reached!: ", pred_cr) + if new_added_pts_num < 10: + print("min added pts num reached!: ", new_added_pts_num) if pred_cr <= last_pred_cr + cr_increase_threshold: retry += 1 retry_duplication_pose.append(pred_pose.cpu().numpy().tolist()) @@ -140,17 +145,14 @@ class Inferencer(Runner): retry = 0 pred_cr_seq.append(pred_cr) - scanned_view_pts.append(new_target_pts_world) - down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_pts_world, input_pts_N) - new_pts_world_aug = np.hstack([down_sampled_new_pts_world, np.ones((down_sampled_new_pts_world.shape[0], 1))]) - new_pts = np.dot(np.linalg.inv(first_frame_to_world.cpu()), new_pts_world_aug.T).T[:,:3] - - new_pts_tensor = torch.tensor(new_pts, dtype=torch.float32).unsqueeze(0).to(self.device) + scanned_view_pts.append(new_target_pts) + down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_target_pts, input_pts_N) - input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0] , new_pts_tensor], dim=0)] + new_pts = down_sampled_new_pts_world input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)] - combined_scanned_views_pts = np.concatenate(input_data["scanned_pts"][0].tolist(), axis=0) - voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002) + + combined_scanned_pts = np.concatenate([input_data["combined_scanned_pts"][0].cpu().numpy(), new_pts], axis=0) + voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, 0.002) random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N) input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device) diff --git a/utils/render.py b/utils/render.py index f2e95b0..639be89 100644 --- a/utils/render.py +++ b/utils/render.py @@ -4,11 +4,54 @@ import json import subprocess import tempfile import shutil +import numpy as np from utils.data_load import DataLoadUtil from utils.reconstruction import ReconstructionUtil from utils.pts import PtsUtil class RenderUtil: + target_mask_label = (0, 255, 0) + display_table_mask_label = (0, 0, 255) + random_downsample_N = 32768 + min_z = 0.2 + max_z = 0.5 + @staticmethod + def get_world_points_and_normal(depth, mask, normal, cam_intrinsic, cam_extrinsic, random_downsample_N): + z = depth[mask] + i, j = np.nonzero(mask) + x = (j - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] + y = (i - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] + + points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) + normal_camera = normal[mask].reshape(-1, 3) + sampled_target_points, idx = PtsUtil.random_downsample_point_cloud( + points_camera, random_downsample_N, require_idx=True + ) + if len(sampled_target_points) == 0: + return np.zeros((0, 3)), np.zeros((0, 3)) + sampled_normal_camera = normal_camera[idx] + + points_camera_aug = np.concatenate((sampled_target_points, np.ones((sampled_target_points.shape[0], 1))), axis=-1) + points_camera_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3] + + return points_camera_world, sampled_normal_camera + + @staticmethod + def get_world_points(depth, mask, cam_intrinsic, cam_extrinsic, random_downsample_N): + z = depth[mask] + i, j = np.nonzero(mask) + x = (j - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] + y = (i - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] + + points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) + sampled_target_points = PtsUtil.random_downsample_point_cloud( + points_camera, random_downsample_N + ) + points_camera_aug = np.concatenate((sampled_target_points, np.ones((sampled_target_points.shape[0], 1))), axis=-1) + points_camera_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3] + + return points_camera_world + @staticmethod def render_pts(cam_pose, scene_path, script_path, voxel_threshold=0.005, filter_degree=75, nO_to_nL_pose=None, require_full_scene=False): @@ -28,25 +71,50 @@ class RenderUtil: result = subprocess.run([ 'blender', '-b', '-P', script_path, '--', temp_dir ], capture_output=True, text=True) + if result.returncode != 0: print("Blender script failed:") print(result.stderr) return None path = os.path.join(temp_dir, "tmp") - point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True) - normals = DataLoadUtil.get_target_normals_world_from_path(path, binocular=True) - cam_params = DataLoadUtil.load_cam_info(path, binocular=True) - - filtered_point_cloud = PtsUtil.filter_points(point_cloud, normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=filter_degree) - full_scene_point_cloud = None - if require_full_scene: - depth_L, depth_R = DataLoadUtil.load_depth(path, cam_params['near_plane'], cam_params['far_plane'], binocular=True) - point_cloud_L = DataLoadUtil.get_point_cloud(depth_L, cam_params['cam_intrinsic'], cam_params['cam_to_world'])['points_world'] - point_cloud_R = DataLoadUtil.get_point_cloud(depth_R, cam_params['cam_intrinsic'], cam_params['cam_to_world_R'])['points_world'] - - point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, 65536) - point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536) - full_scene_point_cloud = PtsUtil.get_overlapping_points(point_cloud_L, point_cloud_R) + cam_info = DataLoadUtil.load_cam_info(path, binocular=True) + depth_L, depth_R = DataLoadUtil.load_depth( + path, cam_info["near_plane"], + cam_info["far_plane"], + binocular=True + ) + mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True) + normal_L = DataLoadUtil.load_normal(path, binocular=True, left_only=True) + ''' target points ''' + mask_img_L = mask_L + mask_img_R = mask_R + + target_mask_img_L = (mask_L == RenderUtil.target_mask_label).all(axis=-1) + target_mask_img_R = (mask_R == RenderUtil.target_mask_label).all(axis=-1) - return filtered_point_cloud, full_scene_point_cloud \ No newline at end of file + sampled_target_points_L, sampled_target_normal_L = RenderUtil.get_world_points_and_normal(depth_L,target_mask_img_L,normal_L, cam_info["cam_intrinsic"], cam_info["cam_to_world"], RenderUtil.random_downsample_N) + sampled_target_points_R = RenderUtil.get_world_points(depth_R, target_mask_img_R, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"], RenderUtil.random_downsample_N ) + + + has_points = sampled_target_points_L.shape[0] > 0 and sampled_target_points_R.shape[0] > 0 + if has_points: + target_points, overlap_idx = PtsUtil.get_overlapping_points( + sampled_target_points_L, sampled_target_points_R, voxel_threshold, require_idx=True + ) + sampled_target_normal_L = sampled_target_normal_L[overlap_idx] + + if has_points: + has_points = target_points.shape[0] > 0 + + if has_points: + target_points, target_normals = PtsUtil.filter_points( + target_points, sampled_target_normal_L, cam_info["cam_to_world"], theta_limit = filter_degree, z_range=(RenderUtil.min_z, RenderUtil.max_z) + ) + + if not has_points: + target_points = np.zeros((0, 3)) + target_normals = np.zeros((0, 3)) + + #import ipdb; ipdb.set_trace() + return target_points, target_normals \ No newline at end of file