diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml
index 721f327..b12e2ef 100644
--- a/configs/server/server_train_config.yaml
+++ b/configs/server/server_train_config.yaml
@@ -28,7 +28,7 @@ runner:
     #- OmniObject3d_test
     - OmniObject3d_val
 
-  pipeline: nbv_reconstruction_global_pts_n_num_pipeline
+  pipeline: nbv_reconstruction_pipeline
 
 dataset:
   OmniObject3d_train:
diff --git a/core/pipeline.py b/core/pipeline.py
index 797fd87..ec05e32 100644
--- a/core/pipeline.py
+++ b/core/pipeline.py
@@ -92,7 +92,9 @@ class NBVReconstructionPipeline(nn.Module):
         scanned_n_to_world_pose_9d_batch = data[
             "scanned_n_to_world_pose_9d"
         ] # List(B): Tensor(S x 9)
-
+        scanned_pts_batch = data[
+            "scanned_pts"
+        ]
         device = next(self.parameters()).device
 
         embedding_list_batch = []
@@ -102,11 +104,13 @@ class NBVReconstructionPipeline(nn.Module):
             combined_scanned_pts_batch, require_per_point_feat=False
         ) # global_scanned_feat: Tensor(B x Dg)
 
-        for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
+        for scanned_n_to_world_pose_9d, scanned_pts in zip(scanned_n_to_world_pose_9d_batch, scanned_pts_batch):
             scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9)
+            scanned_pts = scanned_pts.to(device) # Tensor(S x N x 3)
             pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp)
-            seq_embedding = pose_feat_seq
-            embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp))
+            pts_feat_seq = self.pts_encoder.encode_points(scanned_pts, require_per_point_feat=False) # Tensor(S x Dl)
+            seq_embedding = torch.cat([pose_feat_seq, pts_feat_seq], dim=-1) # Tensor(S x (Dp+Dl))
+            embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dl))
 
         seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
         main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
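
For reference, a minimal, self-contained sketch of the per-view embedding the updated loop builds: each scanned view now contributes a pose feature concatenated with a pooled local point-cloud feature, instead of the pose feature alone. The linear/max-pool stand-ins and the sizes S, N, Dp, Dl below are illustrative assumptions, not the repo's actual pose_encoder/pts_encoder.

# Sketch only: stand-in encoders with the same tensor shapes as the diff above.
import torch
import torch.nn as nn

S, N = 5, 1024       # assumed: views per scene, points per view
Dp, Dl = 256, 1024   # assumed: pose feature dim, local point feature dim

pose_encoder = nn.Linear(9, Dp)  # stand-in for self.pose_encoder.encode_pose
pts_encoder = nn.Linear(3, Dl)   # stand-in for self.pts_encoder.encode_points

scanned_n_to_world_pose_9d = torch.randn(S, 9)     # Tensor(S x 9)
scanned_pts = torch.randn(S, N, 3)                 # Tensor(S x N x 3)

pose_feat_seq = pose_encoder(scanned_n_to_world_pose_9d)          # Tensor(S x Dp)
pts_feat_seq = pts_encoder(scanned_pts).max(dim=1).values         # Tensor(S x Dl), pooled over N points
seq_embedding = torch.cat([pose_feat_seq, pts_feat_seq], dim=-1)  # Tensor(S x (Dp+Dl))
assert seq_embedding.shape == (S, Dp + Dl)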