fix bug for training

This commit is contained in:
2024-09-12 15:11:09 +08:00
parent a79ca7749d
commit 4c69ed777b
15 changed files with 201 additions and 120 deletions

View File

@@ -54,6 +54,7 @@ class PointNetEncoder(nn.Module):
def encode_points(self, pts):
pts = pts.transpose(2, 1)
if not self.global_feat:
pts_feature = self(pts).transpose(2, 1)
else:
@@ -98,11 +99,24 @@ class STNkd(nn.Module):
if __name__ == "__main__":
sim_data = Variable(torch.rand(32, 2500, 3))
pointnet_global = PointNetEncoder(global_feat=True)
config = {
"in_dim": 3,
"out_dim": 1024,
"global_feat": True,
"feature_transform": False
}
pointnet_global = PointNetEncoder(config)
out = pointnet_global.encode_points(sim_data)
print("global feat", out.size())
pointnet = PointNetEncoder(global_feat=False)
config = {
"in_dim": 3,
"out_dim": 1024,
"global_feat": False,
"feature_transform": False
}
pointnet = PointNetEncoder(config)
out = pointnet.encode_points(sim_data)
print("point feat", out.size())