add transformer seq encoder and add seq_feat in gf_view_finder
This commit is contained in:
@@ -54,12 +54,12 @@ class GradientFieldViewFinder(ViewFinder):
|
||||
if not self.per_point_feature:
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
@@ -72,15 +72,13 @@ class GradientFieldViewFinder(ViewFinder):
|
||||
"""
|
||||
Args:
|
||||
data, dict {
|
||||
'target_pts_feat': [bs, c]
|
||||
'scene_pts_feat': [bs, c]
|
||||
'seq_feat': [bs, c]
|
||||
'pose_sample': [bs, pose_dim]
|
||||
't': [bs, 1]
|
||||
}
|
||||
"""
|
||||
|
||||
scene_pts_feat = data['scene_feat']
|
||||
target_pts_feat = data['target_feat']
|
||||
seq_feat = data['seq_feat']
|
||||
sampled_pose = data['sampled_pose']
|
||||
t = data['t']
|
||||
t_feat = self.t_encoder(t.squeeze(1))
|
||||
@@ -89,7 +87,7 @@ class GradientFieldViewFinder(ViewFinder):
|
||||
if self.per_point_feature:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1)
|
||||
total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)
|
||||
_, std = self.marginal_prob_fn(total_feat, t)
|
||||
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
@@ -106,20 +104,7 @@ class GradientFieldViewFinder(ViewFinder):
|
||||
|
||||
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
|
||||
|
||||
if self.sample_mode == 'pc':
|
||||
in_process_sample, res = flib.cond_pc_sampler(
|
||||
score_model=self,
|
||||
data=data,
|
||||
prior=self.prior_fn,
|
||||
sde_coeff=self.sde_fn,
|
||||
num_steps=self.sampling_steps,
|
||||
snr=snr,
|
||||
eps=self.sampling_eps,
|
||||
pose_mode=self.pose_mode,
|
||||
init_x=init_x
|
||||
)
|
||||
|
||||
elif self.sample_mode == 'ode':
|
||||
if self.sample_mode == 'ode':
|
||||
T0 = self.T if T0 is None else T0
|
||||
in_process_sample, res = flib.cond_ode_sampler(
|
||||
score_model=self,
|
||||
@@ -140,10 +125,9 @@ class GradientFieldViewFinder(ViewFinder):
|
||||
|
||||
return in_process_sample, res
|
||||
|
||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
||||
def next_best_view(self, seq_feat):
|
||||
data = {
|
||||
'scene_feat': scene_pts_feat,
|
||||
'target_feat': target_pts_feat,
|
||||
'seq_feat': seq_feat,
|
||||
}
|
||||
in_process_sample, res = self.sample(data)
|
||||
return res.to(dtype=torch.float32), in_process_sample
|
||||
|
Reference in New Issue
Block a user