global_only: pipeline
parent a21538c90a
commit f533104e4a
@@ -116,16 +116,16 @@ module:
     feature_transform: False
 
   transformer_seq_encoder:
-    embed_dim: 384
+    embed_dim: 256
     num_heads: 4
     ffn_dim: 256
     num_layers: 3
-    output_dim: 2048
+    output_dim: 1024
 
   gf_view_finder:
     t_feat_dim: 128
     pose_feat_dim: 256
-    main_feat_dim: 3072
+    main_feat_dim: 2048
     regression_head: Rx_Ry_and_T
     pose_mode: rot_matrix
     per_point_feature: False
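A hedged reading of these numbers (the arithmetic is an inference, not stated in the hunk): since the simplified pipeline below feeds only the pose feature into the sequence encoder, embed_dim presumably falls back to the pose feature width, and main_feat_dim is the transformer output_dim plus the global point-cloud feature size Dg. Both the old (2048 + 1024 = 3072) and new (1024 + 1024 = 2048) values are consistent with Dg = 1024:

    # Dg = 1024 is an assumption; it does not appear in this hunk, but it makes
    # both the old and the new main_feat_dim values add up.
    transformer_output_dim = 1024  # transformer_seq_encoder.output_dim (Ds)
    global_feat_dim = 1024         # assumed pts_encoder global feature size (Dg)
    main_feat_dim = 2048           # gf_view_finder.main_feat_dim
    assert main_feat_dim == transformer_output_dim + global_feat_dim  # Ds + Dg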
@@ -7,10 +7,10 @@ from PytorchBoot.factory.component_factory import ComponentFactory
 from PytorchBoot.utils import Log
 
 
-@stereotype.pipeline("nbv_reconstruction_global_pts_n_num_pipeline")
-class NBVReconstructionGlobalPointsPipeline(nn.Module):
+@stereotype.pipeline("nbv_reconstruction_pipeline")
+class NBVReconstructionPipeline(nn.Module):
     def __init__(self, config):
-        super(NBVReconstructionGlobalPointsPipeline, self).__init__()
+        super(NBVReconstructionPipeline, self).__init__()
         self.config = config
         self.module_config = config["modules"]
 
@@ -20,10 +20,6 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
         self.pose_encoder = ComponentFactory.create(
             namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
         )
-        self.pts_num_encoder = ComponentFactory.create(
-            namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
-        )
-
         self.transformer_seq_encoder = ComponentFactory.create(
             namespace.Stereotype.MODULE, self.module_config["transformer_seq_encoder"]
         )
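A toy mirror of the simplified constructor, with nn.Identity stand-ins in place of ComponentFactory.create, just to show which modules survive this commit (the real classes are resolved from the config by PytorchBoot):

    import torch.nn as nn

    class ToyPipeline(nn.Module):
        def __init__(self):
            super().__init__()
            self.pts_encoder = nn.Identity()              # still used for the global point feature
            self.pose_encoder = nn.Identity()             # still encodes each 9D view pose
            self.transformer_seq_encoder = nn.Identity()  # still aggregates the per-view embeddings
            # self.pts_num_encoder: removed by this commit; the point-count feature is gone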
@@ -96,44 +92,21 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
         scanned_n_to_world_pose_9d_batch = data[
             "scanned_n_to_world_pose_9d"
         ]  # List(B): Tensor(S x 9)
-        scanned_pts_mask_batch = data[
-            "scanned_pts_mask"
-        ]  # Tensor(B x N)
 
         device = next(self.parameters()).device
 
         embedding_list_batch = []
 
         combined_scanned_pts_batch = data["combined_scanned_pts"]  # Tensor(B x N x 3)
-        global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points(
-            combined_scanned_pts_batch, require_per_point_feat=True
-        )  # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl)
+        global_scanned_feat = self.pts_encoder.encode_points(
+            combined_scanned_pts_batch, require_per_point_feat=False
+        )  # global_scanned_feat: Tensor(B x Dg)
 
-        for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip(
-            scanned_n_to_world_pose_9d_batch,
-            scanned_pts_mask_batch,
-            perpoint_scanned_feat_batch,
-        ):
-            scanned_target_pts_num = []  # List(S): Int
-            partial_feat_seq = []
-
-            seq_len = len(scanned_n_to_world_pose_9d)
-            for seq_idx in range(seq_len):
-                partial_idx_in_combined_pts = scanned_mask == seq_idx  # Ndarray(V), N->V idx mask
-                partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts]  # Ndarray(V x Dl)
-                partial_feat = torch.mean(partial_perpoint_feat, dim=0)  # Tensor(Dl)
-                partial_feat_seq.append(partial_feat)
-                scanned_target_pts_num.append(partial_perpoint_feat.shape[0])
-
-
-            scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.float32).unsqueeze(-1).to(device)  # Tensor(S x 1)
+        for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
             scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)  # Tensor(S x 9)
             pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)  # Tensor(S x Dp)
-            pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)  # Tensor(S x Dn)
-            partial_feat_seq = torch.stack(partial_feat_seq)  # Tensor(S x Dl)
-            seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1)  # Tensor(S x (Dp+Dn+Dl))
-            embedding_list_batch.append(seq_embedding)  # List(B): Tensor(S x (Dp+Dn+Dl))
-
+            seq_embedding = pose_feat_seq
+            embedding_list_batch.append(seq_embedding)  # List(B): Tensor(S x (Dp))
 
         seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch)  # Tensor(B x Ds)
         main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1)  # Tensor(B x (Ds+Dg))
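To make the new forward data flow concrete, here is a self-contained shape sketch. The stand-in modules below are not the real ones (those come from ComponentFactory), and Dg = 1024 is an assumption; the point is only how the pose-only per-view embeddings and the global point feature combine into main_feat:

    import torch
    import torch.nn as nn

    B, S, N = 2, 4, 8192           # batch, views per scene, combined points
    Dp, Dg, Ds = 256, 1024, 1024   # pose feat, global point feat (assumed), seq feat

    # Stand-ins for the factory-created modules; only the shapes matter here.
    pose_encoder = nn.Linear(9, Dp)                                    # encode_pose: (S x 9) -> (S x Dp)
    pts_encoder = nn.Sequential(nn.Flatten(1), nn.Linear(N * 3, Dg))   # global feature only
    seq_pool = nn.Linear(Dp, Ds)                                       # stands in for transformer_seq_encoder

    pose_batch = [torch.randn(S, 9) for _ in range(B)]   # List(B): Tensor(S x 9)
    combined_pts = torch.randn(B, N, 3)                  # Tensor(B x N x 3)

    global_feat = pts_encoder(combined_pts)                      # Tensor(B x Dg)
    embedding_list = [pose_encoder(p) for p in pose_batch]       # List(B): Tensor(S x Dp)
    seq_feat = torch.stack([seq_pool(e.mean(dim=0)) for e in embedding_list])  # Tensor(B x Ds)

    main_feat = torch.cat([seq_feat, global_feat], dim=-1)       # Tensor(B x (Ds+Dg))
    print(main_feat.shape)                                       # torch.Size([2, 2048])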
|
Loading…
x
Reference in New Issue
Block a user