diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml
index 2bf558d..190e6e3 100644
--- a/configs/server/server_train_config.yaml
+++ b/configs/server/server_train_config.yaml
@@ -116,16 +116,16 @@ module:
     feature_transform: False
 
   transformer_seq_encoder:
-    embed_dim: 384
+    embed_dim: 256
     num_heads: 4
     ffn_dim: 256
     num_layers: 3
-    output_dim: 2048
+    output_dim: 1024
 
   gf_view_finder:
     t_feat_dim: 128
     pose_feat_dim: 256
-    main_feat_dim: 3072
+    main_feat_dim: 2048
     regression_head: Rx_Ry_and_T
     pose_mode: rot_matrix
     per_point_feature: False
diff --git a/core/pipeline.py b/core/pipeline.py
index 164253c..797fd87 100644
--- a/core/pipeline.py
+++ b/core/pipeline.py
@@ -7,10 +7,10 @@ from PytorchBoot.factory.component_factory import ComponentFactory
 from PytorchBoot.utils import Log
 
 
-@stereotype.pipeline("nbv_reconstruction_global_pts_n_num_pipeline")
-class NBVReconstructionGlobalPointsPipeline(nn.Module):
+@stereotype.pipeline("nbv_reconstruction_pipeline")
+class NBVReconstructionPipeline(nn.Module):
     def __init__(self, config):
-        super(NBVReconstructionGlobalPointsPipeline, self).__init__()
+        super(NBVReconstructionPipeline, self).__init__()
         self.config = config
         self.module_config = config["modules"]
 
@@ -20,10 +20,6 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
         self.pose_encoder = ComponentFactory.create(
             namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
         )
-        self.pts_num_encoder = ComponentFactory.create(
-            namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
-        )
-
         self.transformer_seq_encoder = ComponentFactory.create(
             namespace.Stereotype.MODULE, self.module_config["transformer_seq_encoder"]
         )
@@ -96,44 +92,21 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
         scanned_n_to_world_pose_9d_batch = data[
             "scanned_n_to_world_pose_9d"
         ]  # List(B): Tensor(S x 9)
-        scanned_pts_mask_batch = data[
-            "scanned_pts_mask"
-        ]  # Tensor(B x N)
 
         device = next(self.parameters()).device
 
         embedding_list_batch = []
 
         combined_scanned_pts_batch = data["combined_scanned_pts"]  # Tensor(B x N x 3)
-        global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points(
-            combined_scanned_pts_batch, require_per_point_feat=True
-        )  # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl)
+        global_scanned_feat = self.pts_encoder.encode_points(
+            combined_scanned_pts_batch, require_per_point_feat=False
+        )  # global_scanned_feat: Tensor(B x Dg)
 
-        for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip(
-            scanned_n_to_world_pose_9d_batch,
-            scanned_pts_mask_batch,
-            perpoint_scanned_feat_batch,
-        ):
-            scanned_target_pts_num = []  # List(S): Int
-            partial_feat_seq = []
-
-            seq_len = len(scanned_n_to_world_pose_9d)
-            for seq_idx in range(seq_len):
-                partial_idx_in_combined_pts = scanned_mask == seq_idx  # Ndarray(V), N->V idx mask
-                partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts]  # Ndarray(V x Dl)
-                partial_feat = torch.mean(partial_perpoint_feat, dim=0)  # Tensor(Dl)
-                partial_feat_seq.append(partial_feat)
-                scanned_target_pts_num.append(partial_perpoint_feat.shape[0])
-
-
-            scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.float32).unsqueeze(-1).to(device)  # Tensor(S x 1)
+        for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
             scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)  # Tensor(S x 9)
-            pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)  # Tensor(S x Dp)
-            pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)  # Tensor(S x Dn)
-            partial_feat_seq = torch.stack(partial_feat_seq)  # Tensor(S x Dl)
-            seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1)  # Tensor(S x (Dp+Dn+Dl))
-            embedding_list_batch.append(seq_embedding)  # List(B): Tensor(S x (Dp+Dn+Dl))
-
+            pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)  # Tensor(S x Dp)
+            seq_embedding = pose_feat_seq
+            embedding_list_batch.append(seq_embedding)  # List(B): Tensor(S x (Dp))
 
         seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch)  # Tensor(B x Ds)
         main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1)  # Tensor(B x (Ds+Dg))
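
Note (not part of the patch): the per-view embedding fed to the transformer is now the pose feature alone, which is presumably why embed_dim drops from 384 to 256, and with the per-point and point-count branches removed the view-finder input shrinks from 3072 to 2048 (output_dim 1024 plus an assumed 1024-dim global point feature). The sketch below illustrates only this dimension bookkeeping; the nn.Linear stubs are hypothetical stand-ins for the PytorchBoot encoder modules, and Dg = 1024 is inferred, not read from the config.

# Standalone sketch of the simplified embedding flow after this change.
# The stubs are illustrative; only the dimensions follow the diff above
# (embed_dim=256, output_dim=1024, main_feat_dim=2048).
import torch
import torch.nn as nn

B, S, N = 2, 5, 1024   # batch size, views per scene, points in combined cloud
Dp = 256               # per-view pose embedding dim == transformer_seq_encoder.embed_dim
Ds = 1024              # sequence feature dim == transformer_seq_encoder.output_dim
Dg = 1024              # ASSUMED global point feature dim (main_feat_dim - Ds = 2048 - 1024)

pose_encoder = nn.Linear(9, Dp)                                    # stub for pose_encoder.encode_pose
pts_encoder = nn.Sequential(nn.Flatten(1), nn.Linear(N * 3, Dg))   # stub for pts_encoder.encode_points
seq_encoder = nn.Linear(Dp, Ds)                                    # stub for transformer_seq_encoder.encode_sequence

scanned_n_to_world_pose_9d_batch = [torch.randn(S, 9) for _ in range(B)]  # List(B): Tensor(S x 9)
combined_scanned_pts_batch = torch.randn(B, N, 3)                         # Tensor(B x N x 3)

# One global feature per scene; per-point features are no longer requested.
global_scanned_feat = pts_encoder(combined_scanned_pts_batch)             # Tensor(B x Dg)

# Per-view embeddings are now just the encoded 9D poses.
embedding_list_batch = [pose_encoder(p) for p in scanned_n_to_world_pose_9d_batch]  # List(B): Tensor(S x Dp)

# Stand-in for the transformer sequence encoder: mean-pool each sequence, then project.
seq_feat = seq_encoder(torch.stack(embedding_list_batch).mean(dim=1))     # Tensor(B x Ds)

main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1)            # Tensor(B x (Ds+Dg))
assert main_feat.shape == (B, 2048)  # matches gf_view_finder.main_feat_dim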