after first overfit test

This commit is contained in:
2024-09-18 06:49:59 +00:00
parent d80d0ea79d
commit 0280dc7292
6 changed files with 193 additions and 42 deletions

View File

@@ -5,7 +5,7 @@ import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory
from PytorchBoot.utils import Log
@stereotype.pipeline("nbv_reconstruction_pipeline")
@stereotype.pipeline("nbv_reconstruction_pipeline", comment="should be tested")
class NBVReconstructionPipeline(nn.Module):
def __init__(self, config):
super(NBVReconstructionPipeline, self).__init__()
@@ -72,10 +72,14 @@ class NBVReconstructionPipeline(nn.Module):
pose_feat_seq_list = []
for scanned_pts,scanned_n_to_1_pose_9d in zip(scanned_pts_batch,scanned_n_to_1_pose_9d_batch):
scanned_pts = scanned_pts.to(best_to_1_pose_9d_batch.device)
scanned_n_to_1_pose_9d = scanned_n_to_1_pose_9d.to(best_to_1_pose_9d_batch.device)
pts_feat_seq_list.append(self.pts_encoder.encode_points(scanned_pts))
pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_1_pose_9d))
seq_feat = self.seq_encoder.encode_sequence(pts_feat_seq_list, pose_feat_seq_list)
if torch.isnan(seq_feat).any():
Log.error("nan in seq_feat", True)
return seq_feat