Add transformer sequence encoder and pass seq_feat into gf_view_finder

This commit is contained in:
hofee
2024-08-22 00:24:28 +08:00
parent 68b4325dbd
commit b06dede4b8
3 changed files with 88 additions and 252 deletions

View File

@@ -54,12 +54,12 @@ class GradientFieldViewFinder(ViewFinder):
if not self.per_point_feature:
''' rotation_x_axis regress head '''
self.fusion_tail_rot_x = nn.Sequential(
nn.Linear(128 + 256 + 1024 + 1024, 256),
nn.Linear(128 + 256 + 2048, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
self.fusion_tail_rot_y = nn.Sequential(
nn.Linear(128 + 256 + 1024 + 1024, 256),
nn.Linear(128 + 256 + 2048, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
@@ -72,15 +72,13 @@ class GradientFieldViewFinder(ViewFinder):
"""
Args:
data, dict {
'target_pts_feat': [bs, c]
'scene_pts_feat': [bs, c]
'seq_feat': [bs, c]
'sampled_pose': [bs, pose_dim]
't': [bs, 1]
}
"""
scene_pts_feat = data['scene_feat']
target_pts_feat = data['target_feat']
seq_feat = data['seq_feat']
sampled_pose = data['sampled_pose']
t = data['t']
t_feat = self.t_encoder(t.squeeze(1))
@@ -89,7 +87,7 @@ class GradientFieldViewFinder(ViewFinder):
if self.per_point_feature:
raise NotImplementedError
else:
total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1)
total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)
_, std = self.marginal_prob_fn(total_feat, t)
if self.regression_head == 'Rx_Ry':
@@ -106,20 +104,7 @@ class GradientFieldViewFinder(ViewFinder):
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
if self.sample_mode == 'pc':
in_process_sample, res = flib.cond_pc_sampler(
score_model=self,
data=data,
prior=self.prior_fn,
sde_coeff=self.sde_fn,
num_steps=self.sampling_steps,
snr=snr,
eps=self.sampling_eps,
pose_mode=self.pose_mode,
init_x=init_x
)
elif self.sample_mode == 'ode':
if self.sample_mode == 'ode':
T0 = self.T if T0 is None else T0
in_process_sample, res = flib.cond_ode_sampler(
score_model=self,
@@ -140,10 +125,9 @@ class GradientFieldViewFinder(ViewFinder):
return in_process_sample, res
def next_best_view(self, scene_pts_feat, target_pts_feat):
def next_best_view(self, seq_feat):
data = {
'scene_feat': scene_pts_feat,
'target_feat': target_pts_feat,
'seq_feat': seq_feat,
}
in_process_sample, res = self.sample(data)
return res.to(dtype=torch.float32), in_process_sample