Compare commits

..

No commits in common. "fa69f9f8794ddaac6c89ca86048ab73f37545e40" and "e315fd99ee7780e4771683f6ea970ab9c6688093" have entirely different histories.

3 changed files with 92 additions and 142 deletions

View File

@ -12,55 +12,41 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
super(NBVReconstructionGlobalPointsPipeline, self).__init__()
self.config = config
self.module_config = config["modules"]
self.pts_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["pts_encoder"]
)
self.pose_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
)
self.pose_n_num_seq_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]
)
self.view_finder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["view_finder"]
)
self.pts_num_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
)
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"])
self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"])
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"])
self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"])
self.eps = float(self.config["eps"])
self.enable_global_scanned_feat = self.config["global_scanned_feat"]
def forward(self, data):
mode = data["mode"]
if mode == namespace.Mode.TRAIN:
return self.forward_train(data)
elif mode == namespace.Mode.TEST:
return self.forward_test(data)
else:
Log.error("Unknown mode: {}".format(mode), True)
def pertube_data(self, gt_delta_9d):
bs = gt_delta_9d.shape[0]
random_t = (
torch.rand(bs, device=gt_delta_9d.device) * (1.0 - self.eps) + self.eps
)
random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps
random_t = random_t.unsqueeze(-1)
mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
std = std.view(-1, 1)
z = torch.randn_like(gt_delta_9d)
perturbed_x = mu + z * std
target_score = -z * std / (std**2)
target_score = - z * std / (std ** 2)
return perturbed_x, random_t, target_score, std
def forward_train(self, data):
main_feat = self.get_main_feat(data)
""" get std """
''' get std '''
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
perturbed_x, random_t, target_score, std = self.pertube_data(
best_to_world_pose_9d_batch
)
perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch)
input_data = {
"sampled_pose": perturbed_x,
"t": random_t,
@ -70,69 +56,45 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
output = {
"estimated_score": estimated_score,
"target_score": target_score,
"std": std,
"std": std
}
return output
def forward_test(self, data):
def forward_test(self,data):
main_feat = self.get_main_feat(data)
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(
main_feat
)
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat)
result = {
"pred_pose_9d": estimated_delta_rot_9d,
"in_process_sample": in_process_sample,
"in_process_sample": in_process_sample
}
return result
def get_main_feat(self, data):
scanned_n_to_world_pose_9d_batch = data[
"scanned_n_to_world_pose_9d"
] # List(B): Tensor(S x 9)
scanned_pts_mask_batch = data[
"scanned_pts_mask"
] # Tensor(B x N)
scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d']
scanned_target_pts_num_batch = data['scanned_target_points_num']
device = next(self.parameters()).device
embedding_list_batch = []
for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch):
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
scanned_target_pts_num = scanned_target_pts_num.to(device)
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
combined_scanned_pts_batch = data["combined_scanned_pts"] # Tensor(B x N x 3)
global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points(
combined_scanned_pts_batch, require_per_point_feat=True
) # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl)
for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip(
scanned_n_to_world_pose_9d_batch,
scanned_pts_mask_batch,
perpoint_scanned_feat_batch,
):
scanned_target_pts_num = [] # List(S): Int
partial_feat_seq = []
seq_len = len(scanned_n_to_world_pose_9d)
for seq_idx in range(seq_len):
partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask
partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl)
partial_feat = torch.mean(partial_perpoint_feat, dim=0)[0] # Tensor(Dl)
partial_feat_seq.append(partial_feat)
scanned_target_pts_num.append(partial_perpoint_feat.shape[0])
scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.int32).to(device) # Tensor(S)
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9)
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp)
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) # Tensor(S x Dn)
partial_feat_seq = torch.stack(partial_feat_seq) # Tensor(S x Dl)
seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl))
embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl))
seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)
combined_scanned_pts_batch = data['combined_scanned_pts']
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
if torch.isnan(main_feat).any():
Log.error("nan in main_feat", True)
return main_feat

View File

@ -122,6 +122,7 @@ class NBVReconstructionDataset(BaseDataset):
scanned_views_pts,
scanned_coverages_rate,
scanned_n_to_world_pose,
scanned_target_pts_num,
) = ([], [], [], [])
for view in scanned_views:
frame_idx = view[0]
@ -133,6 +134,7 @@ class NBVReconstructionDataset(BaseDataset):
target_point_cloud = (
DataLoadUtil.load_from_preprocessed_pts(view_path)
)
target_pts_num = target_point_cloud.shape[0]
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(
target_point_cloud, self.pts_num
)
@ -144,7 +146,7 @@ class NBVReconstructionDataset(BaseDataset):
n_to_world_trans = n_to_world_pose[:3, 3]
n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
scanned_n_to_world_pose.append(n_to_world_9d)
scanned_target_pts_num.append(target_pts_num)
nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
@ -160,33 +162,30 @@ class NBVReconstructionDataset(BaseDataset):
)
combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
combined_scanned_views_pts, self.pts_num, require_idx=True
voxel_downsampled_combined_scanned_pts_np = (
PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002)
)
random_downsampled_combined_scanned_pts_np = (
PtsUtil.random_downsample_point_cloud(
voxel_downsampled_combined_scanned_pts_np, self.pts_num
)
)
combined_scanned_views_pts_mask = np.zeros(len(scanned_views_pts), dtype=np.uint8)
start_idx = 0
for i in range(len(scanned_views_pts)):
end_idx = start_idx + len(scanned_views_pts[i])
combined_scanned_views_pts_mask[start_idx:end_idx] = i
start_idx = end_idx
fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
data_item = {
"scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3)
"scanned_pts_mask": np.asarray(fps_downsampled_combined_scanned_pts_mask,dtype=np.uint8), # Ndarray(N), range(0, S)
"combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32), # Ndarray(N x 3)
"scanned_coverage_rate": scanned_coverages_rate, # List(S): Float, range(0, 1)
"scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32), # Ndarray(S x 9)
"best_coverage_rate": nbv_coverage_rate, # Float, range(0, 1)
"best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), # Ndarray(9)
"seq_max_coverage_rate": max_coverage_rate, # Float, range(0, 1)
"scene_name": scene_name, # String
"scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32),
"combined_scanned_pts": np.asarray(
random_downsampled_combined_scanned_pts_np, dtype=np.float32
),
"scanned_coverage_rate": scanned_coverages_rate,
"scanned_n_to_world_pose_9d": np.asarray(
scanned_n_to_world_pose, dtype=np.float32
),
"best_coverage_rate": nbv_coverage_rate,
"best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32),
"seq_max_coverage_rate": max_coverage_rate,
"scene_name": scene_name,
"scanned_target_points_num": np.asarray(
scanned_target_pts_num, dtype=np.int32
),
}
return data_item
@ -197,35 +196,33 @@ class NBVReconstructionDataset(BaseDataset):
def get_collate_fn(self):
def collate_fn(batch):
collate_data = {}
''' ------ Varialbe Length ------ '''
collate_data["scanned_pts"] = [
torch.tensor(item["scanned_pts"]) for item in batch
]
collate_data["scanned_n_to_world_pose_9d"] = [
torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
]
''' ------ Fixed Length ------ '''
collate_data["scanned_target_points_num"] = [
torch.tensor(item["scanned_target_points_num"]) for item in batch
]
collate_data["best_to_world_pose_9d"] = torch.stack(
[torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
)
collate_data["combined_scanned_pts"] = torch.stack(
[torch.tensor(item["combined_scanned_pts"]) for item in batch]
)
collate_data["scanned_pts_mask"] = torch.stack(
[torch.tensor(item["scanned_pts_mask"]) for item in batch]
)
if "first_frame_to_world" in batch[0]:
collate_data["first_frame_to_world"] = torch.stack(
[torch.tensor(item["first_frame_to_world"]) for item in batch]
)
for key in batch[0].keys():
if key not in [
"scanned_pts",
"scanned_pts_mask",
"scanned_n_to_world_pose_9d",
"best_to_world_pose_9d",
"first_frame_to_world",
"combined_scanned_pts",
"scanned_target_points_num",
]:
collate_data[key] = [item[key] for item in batch]
return collate_data

View File

@ -12,6 +12,12 @@ class PtsUtil:
downsampled_pc = o3d_pc.voxel_down_sample(voxel_size)
return np.asarray(downsampled_pc.points)
@staticmethod
def transform_point_cloud(points, pose_mat):
points_h = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
points_h = np.dot(pose_mat, points_h.T).T
return points_h[:, :3]
@staticmethod
def random_downsample_point_cloud(point_cloud, num_points, require_idx=False):
if point_cloud.shape[0] == 0:
@ -23,27 +29,6 @@ class PtsUtil:
return point_cloud[idx], idx
return point_cloud[idx]
@staticmethod
def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False):
N = point_cloud.shape[0]
mask = np.zeros(N, dtype=bool)
sampled_indices = np.zeros(num_points, dtype=int)
sampled_indices[0] = np.random.randint(0, N)
distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
for i in range(1, num_points):
farthest_index = np.argmax(distances)
sampled_indices[i] = farthest_index
mask[farthest_index] = True
new_distances = np.linalg.norm(point_cloud - point_cloud[farthest_index], axis=1)
distances = np.minimum(distances, new_distances)
sampled_points = point_cloud[sampled_indices]
if require_idx:
return sampled_points, sampled_indices
return sampled_points
@staticmethod
def random_downsample_point_cloud_tensor(point_cloud, num_points):
idx = torch.randint(0, len(point_cloud), (num_points,))
@ -55,12 +40,6 @@ class PtsUtil:
unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
return unique_voxels
@staticmethod
def transform_point_cloud(points, pose_mat):
points_h = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
points_h = np.dot(pose_mat, points_h.T).T
return points_h[:, :3]
@staticmethod
def get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size=0.005, require_idx=False):
voxels_L, indices_L = PtsUtil.voxelize_points(point_cloud_L, voxel_size)
@ -77,6 +56,18 @@ class PtsUtil:
return overlapping_points, mask_L
return overlapping_points
@staticmethod
def new_filter_points(points, normals, cam_pose, theta=75, require_idx=False):
camera_axis = -cam_pose[:3, 2]
normals_normalized = normals / np.linalg.norm(normals, axis=1, keepdims=True)
cos_theta = np.dot(normals_normalized, camera_axis)
theta_rad = np.deg2rad(theta)
idx = cos_theta > np.cos(theta_rad)
filtered_points= points[idx]
if require_idx:
return filtered_points, idx
return filtered_points
@staticmethod
def filter_points(points, points_normals, cam_pose, voxel_size=0.002, theta=45, z_range=(0.2, 0.45)):