diff --git a/beans/predict_result.py b/beans/predict_result.py new file mode 100644 index 0000000..db99270 --- /dev/null +++ b/beans/predict_result.py @@ -0,0 +1,162 @@ +import numpy as np +from sklearn.cluster import DBSCAN + +class PredictResult: + def __init__(self, raw_predict_result, input_pts=None, cluster_params=dict(eps=0.5, min_samples=2)): + self.input_pts = input_pts + self.cluster_params = cluster_params + self.sampled_9d_pose = raw_predict_result + self.sampled_matrix_pose = self.get_sampled_matrix_pose() + self.distance_matrix = self.calculate_distance_matrix() + self.clusters = self.get_cluster_result() + self.candidate_matrix_poses = self.get_candidate_poses() + self.candidate_9d_poses = [np.concatenate((self.matrix_to_rotation_6d_numpy(matrix[:3,:3]), matrix[:3,3].reshape(-1,)), axis=-1) for matrix in self.candidate_matrix_poses] + self.cluster_num = len(self.clusters) + + @staticmethod + def rotation_6d_to_matrix_numpy(d6): + a1, a2 = d6[:3], d6[3:] + b1 = a1 / np.linalg.norm(a1) + b2 = a2 - np.dot(b1, a2) * b1 + b2 = b2 / np.linalg.norm(b2) + b3 = np.cross(b1, b2) + return np.stack((b1, b2, b3), axis=-2) + + @staticmethod + def matrix_to_rotation_6d_numpy(matrix): + return np.copy(matrix[:2, :]).reshape((6,)) + + def __str__(self): + info = "Predict Result:\n" + info += f" Predicted pose number: {len(self.sampled_9d_pose)}\n" + info += f" Cluster number: {self.cluster_num}\n" + for i, cluster in enumerate(self.clusters): + info += f" - Cluster {i} size: {len(cluster)}\n" + max_distance = np.max(self.distance_matrix[self.distance_matrix != 0]) + min_distance = np.min(self.distance_matrix[self.distance_matrix != 0]) + info += f" Max distance: {max_distance}\n" + info += f" Min distance: {min_distance}\n" + return info + + def get_sampled_matrix_pose(self): + sampled_matrix_pose = [] + for pose in self.sampled_9d_pose: + rotation = pose[:6] + translation = pose[6:] + pose = self.rotation_6d_to_matrix_numpy(rotation) + pose = np.concatenate((pose, translation.reshape(-1, 1)), axis=-1) + pose = np.concatenate((pose, np.array([[0, 0, 0, 1]])), axis=-2) + sampled_matrix_pose.append(pose) + return np.array(sampled_matrix_pose) + + def rotation_distance(self, R1, R2): + R = np.dot(R1.T, R2) + trace = np.trace(R) + angle = np.arccos(np.clip((trace - 1) / 2, -1, 1)) + return angle + + def calculate_distance_matrix(self): + n = len(self.sampled_matrix_pose) + dist_matrix = np.zeros((n, n)) + for i in range(n): + for j in range(n): + dist_matrix[i, j] = self.rotation_distance(self.sampled_matrix_pose[i][:3, :3], self.sampled_matrix_pose[j][:3, :3]) + return dist_matrix + + def cluster_rotations(self): + clustering = DBSCAN(eps=self.cluster_params['eps'], min_samples=self.cluster_params['min_samples'], metric='precomputed') + labels = clustering.fit_predict(self.distance_matrix) + return labels + + def get_cluster_result(self): + labels = self.cluster_rotations() + cluster_num = len(set(labels)) - (1 if -1 in labels else 0) + clusters = [] + for _ in range(cluster_num): + clusters.append([]) + for matrix_pose, label in zip(self.sampled_matrix_pose, labels): + if label != -1: + clusters[label].append(matrix_pose) + clusters.sort(key=len, reverse=True) + return clusters + + def get_center_matrix_pose_from_cluster(self, cluster): + min_total_distance = float('inf') + center_matrix_pose = None + + for matrix_pose in cluster: + total_distance = 0 + for other_matrix_pose in cluster: + rot_distance = self.rotation_distance(matrix_pose[:3, :3], other_matrix_pose[:3, :3]) + total_distance += rot_distance + + if total_distance < min_total_distance: + min_total_distance = total_distance + center_matrix_pose = matrix_pose + + return center_matrix_pose + + def get_candidate_poses(self): + candidate_poses = [] + for cluster in self.clusters: + candidate_poses.append(self.get_center_matrix_pose_from_cluster(cluster)) + return candidate_poses + + def visualize(self): + import plotly.graph_objects as go + fig = go.Figure() + if self.input_pts is not None: + fig.add_trace(go.Scatter3d( + x=self.input_pts[:, 0], y=self.input_pts[:, 1], z=self.input_pts[:, 2], + mode='markers', marker=dict(size=1, color='gray', opacity=0.5), name='Input Points' + )) + colors = ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance', + 'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg'] + for i, cluster in enumerate(self.clusters): + color = colors[i] + candidate_pose = self.candidate_matrix_poses[i] + origin_candidate = candidate_pose[:3, 3] + z_axis_candidate = candidate_pose[:3, 2] + for pose in cluster: + origin = pose[:3, 3] + z_axis = pose[:3, 2] + fig.add_trace(go.Cone( + x=[origin[0]], y=[origin[1]], z=[origin[2]], + u=[z_axis[0]], v=[z_axis[1]], w=[z_axis[2]], + colorscale=color, + sizemode="absolute", sizeref=0.05, anchor="tail", showscale=False + )) + fig.add_trace(go.Cone( + x=[origin_candidate[0]], y=[origin_candidate[1]], z=[origin_candidate[2]], + u=[z_axis_candidate[0]], v=[z_axis_candidate[1]], w=[z_axis_candidate[2]], + colorscale=color, + sizemode="absolute", sizeref=0.1, anchor="tail", showscale=False + )) + + fig.update_layout( + title="Clustered Poses and Input Points", + scene=dict( + xaxis_title='X', + yaxis_title='Y', + zaxis_title='Z' + ), + margin=dict(l=0, r=0, b=0, t=40), + scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)) + ) + + fig.show() + + + +if __name__ == "__main__": + step = 0 + raw_predict_result = np.load(f"inference_result_pack/inference_result_pack/{step}/all_pred_pose_9d.npy") + input_pts = np.loadtxt(f"inference_result_pack/inference_result_pack/{step}/input_pts.txt") + print(raw_predict_result.shape) + predict_result = PredictResult(raw_predict_result, input_pts, cluster_params=dict(eps=0.25, min_samples=3)) + print(predict_result) + print(len(predict_result.candidate_matrix_poses)) + print(predict_result.distance_matrix) + #import ipdb; ipdb.set_trace() + predict_result.visualize() + diff --git a/configs/local/inference_config.yaml b/configs/local/inference_config.yaml index 9092cde..70be656 100644 --- a/configs/local/inference_config.yaml +++ b/configs/local/inference_config.yaml @@ -6,16 +6,16 @@ runner: cuda_visible_devices: "0,1,2,3,4,5,6,7" experiment: - name: train_ab_global_only + name: train_ab_global_only_p++_wp root_dir: "experiments" - epoch: -1 # -1 stands for last epoch + epoch: 922 # -1 stands for last epoch test: dataset_list: - OmniObject3d_test blender_script_path: "/media/hofee/data/project/python/nbv_reconstruction/blender/data_renderer.py" - output_dir: " /media/hofee/data/data/temp" + output_dir: "/media/hofee/data/data/p++_wp_temp_cluster" pipeline: nbv_reconstruction_pipeline voxel_size: 0.003 min_new_area: 1.0 @@ -52,7 +52,7 @@ dataset: pipeline: nbv_reconstruction_pipeline: modules: - pts_encoder: pointnet_encoder + pts_encoder: pointnet++_encoder seq_encoder: transformer_seq_encoder pose_encoder: pose_encoder view_finder: gf_view_finder @@ -60,6 +60,9 @@ pipeline: global_scanned_feat: True module: + pointnet++_encoder: + in_dim: 3 + pointnet_encoder: in_dim: 3 out_dim: 1024 diff --git a/core/pipeline.py b/core/pipeline.py index a43d572..ae04d9e 100644 --- a/core/pipeline.py +++ b/core/pipeline.py @@ -75,6 +75,8 @@ class NBVReconstructionPipeline(nn.Module): def forward_test(self, data): main_feat = self.get_main_feat(data) + repeat_num = data.get("repeat_num", 100) + main_feat = main_feat.repeat(repeat_num, 1) estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view( main_feat ) diff --git a/runners/inferencer.py b/runners/inferencer.py index 658722e..f5cea35 100644 --- a/runners/inferencer.py +++ b/runners/inferencer.py @@ -4,6 +4,7 @@ from utils.render import RenderUtil from utils.pose import PoseUtil from utils.pts import PtsUtil from utils.reconstruction import ReconstructionUtil +from beans.predict_result import PredictResult import torch from tqdm import tqdm @@ -82,6 +83,7 @@ class Inferencer(Runner): data = test_set.__getitem__(i) scene_name = data["scene_name"] inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl") + if os.path.exists(inference_result_path): Log.info(f"Inference result already exists for scene: {scene_name}") continue @@ -142,88 +144,87 @@ class Inferencer(Runner): voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold) output = self.pipeline(input_data) pred_pose_9d = output["pred_pose_9d"] - import ipdb; ipdb.set_trace() pred_pose = torch.eye(4, device=pred_pose_9d.device) - - pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0] - pred_pose[:3,3] = pred_pose_9d[0,6:] - # ----- Debug ----- - - from utils.vis import visualizeUtil - import ipdb; ipdb.set_trace() - all_directions = [] - np.savetxt("input_pts.txt", input_data["combined_scanned_pts"].cpu().numpy()[0]) - for i in range(50): - output = self.pipeline(input_data) - pred_pose_9d = output["pred_pose_9d"] - cam_pos, sample_points = visualizeUtil.get_cam_pose_and_cam_axis(pred_pose_9d.cpu().numpy()[0], is_6d_pose=True) - all_directions.append(sample_points) - all_directions = np.array(all_directions) - reshape_all_directions = all_directions.reshape(-1, 3) - np.savetxt("all_directions.txt", reshape_all_directions) - # ----- ----- ----- - try: - new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose) + # # save pred_pose_9d ------ + # root = "/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction/temp_output_result" + # scene_dir = os.path.join(root, scene_name) + # if not os.path.exists(scene_dir): + # os.makedirs(scene_dir) + # pred_9d_path = os.path.join(scene_dir,f"pred_pose_9d_{len(pred_cr_seq)}.npy") + # pts_path = os.path.join(scene_dir,f"combined_scanned_pts_{len(pred_cr_seq)}.txt") + # np_combined_scanned_pts = input_data["combined_scanned_pts"][0].cpu().numpy() + # np.save(pred_9d_path, pred_pose_9d.cpu().numpy()) + # np.savetxt(pts_path, np_combined_scanned_pts) + # # ----- ----- ----- + pred_pose_9d_candidates = PredictResult(pred_pose_9d.cpu().numpy(), input_pts=input_data["combined_scanned_pts"][0].cpu().numpy(), cluster_params=dict(eps=0.25, min_samples=3)).candidate_9d_poses + for pred_pose_9d in pred_pose_9d_candidates: #import ipdb; ipdb.set_trace() - if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold): - curr_overlap_area_threshold = overlap_area_threshold - else: - curr_overlap_area_threshold = overlap_area_threshold * 0.5 + pred_pose_9d = torch.tensor(pred_pose_9d, dtype=torch.float32).to(self.device).unsqueeze(0) + pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0] + pred_pose[:3,3] = pred_pose_9d[0,6:] + try: + new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose) + #import ipdb; ipdb.set_trace() + if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold): + curr_overlap_area_threshold = overlap_area_threshold + else: + curr_overlap_area_threshold = overlap_area_threshold * 0.5 - downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold) - overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, voxel_downsampled_combined_scanned_pts_np, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True) - # if not overlap: - # Log.yellow("no overlap!") - # retry += 1 - # retry_overlap_pose.append(pred_pose.cpu().numpy().tolist()) - # continue + downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold) + overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, voxel_downsampled_combined_scanned_pts_np, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True) + if not overlap: + Log.yellow("no overlap!") + retry += 1 + retry_overlap_pose.append(pred_pose.cpu().numpy().tolist()) + continue + + history_indices.append(new_scan_points_indices) + except Exception as e: + Log.error(f"Error in scene {scene_path}, {e}") + print("current pose: ", pred_pose) + print("curr_pred_cr: ", last_pred_cr) + retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist()) + retry += 1 + continue - history_indices.append(new_scan_points_indices) - except Exception as e: - Log.error(f"Error in scene {scene_path}, {e}") - print("current pose: ", pred_pose) - print("curr_pred_cr: ", last_pred_cr) - retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist()) - retry += 1 - continue - - if new_target_pts.shape[0] == 0: - Log.red("no pts in new target") - retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist()) - retry += 1 - continue - - pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold) - Log.yellow(f"{pred_cr}, {last_pred_cr}, max: , {data['seq_max_coverage_rate']}") - if pred_cr >= data["seq_max_coverage_rate"] - 1e-3: - print("max coverage rate reached!: ", pred_cr) + if new_target_pts.shape[0] == 0: + Log.red("no pts in new target") + retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist()) + retry += 1 + continue - - - pred_cr_seq.append(pred_cr) - scanned_view_pts.append(new_target_pts) + pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold) + Log.yellow(f"{pred_cr}, {last_pred_cr}, max: , {data['seq_max_coverage_rate']}") + if pred_cr >= data["seq_max_coverage_rate"] - 1e-3: + print("max coverage rate reached!: ", pred_cr) + - input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)] - - combined_scanned_pts = np.vstack(scanned_view_pts) - voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold) - random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N) - input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device) - - last_pred_cr = pred_cr - pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0] - Log.info(f"delta pts num:,{pts_num - last_pts_num },{pts_num}, {last_pts_num}") + pred_cr_seq.append(pred_cr) + scanned_view_pts.append(new_target_pts) + + input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)] + + combined_scanned_pts = np.vstack(scanned_view_pts) + voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold) + random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N) + input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device) - if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2: - retry += 1 - retry_duplication_pose.append(pred_pose.cpu().numpy().tolist()) - Log.red(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}") - elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2: - success += 1 - Log.success(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}") + + last_pred_cr = pred_cr + pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0] + Log.info(f"delta pts num:,{pts_num - last_pts_num },{pts_num}, {last_pts_num}") - last_pts_num = pts_num + if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2: + retry += 1 + retry_duplication_pose.append(pred_pose.cpu().numpy().tolist()) + Log.red(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}") + elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2: + success += 1 + Log.success(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}") + + last_pts_num = pts_num + break input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()