add inference

This commit is contained in:
2024-09-19 00:14:26 +08:00
parent 9ec3a00fd4
commit 935069d68c
10 changed files with 302 additions and 139 deletions

View File

@@ -27,7 +27,7 @@ class NBVReconstructionDataset(BaseDataset):
self.pts_num = config["pts_num"]
self.type = config["type"]
self.cache = config["cache"]
self.cache = config.get("cache")
if self.type == namespace.Mode.TEST:
self.model_dir = config["model_dir"]
self.filter_degree = config["filter_degree"]
@@ -105,7 +105,10 @@ class NBVReconstructionDataset(BaseDataset):
nR_to_world_pose = cam_info["cam_to_world_R"]
n_to_1_pose = np.dot(np.linalg.inv(first_frame_to_world), n_to_world_pose)
nR_to_1_pose = np.dot(np.linalg.inv(first_frame_to_world), nR_to_world_pose)
cached_data = self.load_from_cache(scene_name, first_frame_idx, frame_idx)
cached_data = None
if self.cache:
cached_data = self.load_from_cache(scene_name, first_frame_idx, frame_idx)
if cached_data is None:
depth_L, depth_R = DataLoadUtil.load_depth(view_path, cam_info['near_plane'], cam_info['far_plane'], binocular=True)
@@ -116,7 +119,8 @@ class NBVReconstructionDataset(BaseDataset):
point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536)
overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R)
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(overlap_points, self.pts_num)
self.save_to_cache(scene_name, first_frame_idx, frame_idx, downsampled_target_point_cloud)
if self.cache:
self.save_to_cache(scene_name, first_frame_idx, frame_idx, downsampled_target_point_cloud)
else:
downsampled_target_point_cloud = cached_data