From f42e45d60842d29220229c25e5a880ad8bf774c6 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sun, 29 Sep 2024 20:12:44 +0800 Subject: [PATCH] add per_points_encoder --- modules/pointnet_encoder.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/modules/pointnet_encoder.py b/modules/pointnet_encoder.py index 6483709..6e414f2 100644 --- a/modules/pointnet_encoder.py +++ b/modules/pointnet_encoder.py @@ -22,12 +22,10 @@ class PointNetEncoder(nn.Module): self.conv2 = torch.nn.Conv1d(64, 128, 1) self.conv3 = torch.nn.Conv1d(128, 512, 1) self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1) - self.global_feat = config["global_feat"] if self.feature_transform: self.f_stn = STNkd(k=64) def forward(self, x): - n_pts = x.shape[2] trans = self.stn(x) x = x.transpose(2, 1) x = torch.bmm(x, trans) @@ -46,20 +44,15 @@ class PointNetEncoder(nn.Module): x = self.conv4(x) x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, self.out_dim) - if self.global_feat: - return x - else: - x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts) - return torch.cat([x, point_feat], 1) + return x, point_feat - def encode_points(self, pts): + def encode_points(self, pts, require_per_point_feat=False): pts = pts.transpose(2, 1) - - if not self.global_feat: - pts_feature = self(pts).transpose(2, 1) + global_pts_feature, per_point_feature = self(pts) + if require_per_point_feat: + return global_pts_feature, per_point_feature.transpose(2, 1) else: - pts_feature = self(pts) - return pts_feature + return global_pts_feature class STNkd(nn.Module): def __init__(self, k=64): @@ -102,21 +95,13 @@ if __name__ == "__main__": config = { "in_dim": 3, "out_dim": 1024, - "global_feat": True, "feature_transform": False } - pointnet_global = PointNetEncoder(config) - out = pointnet_global.encode_points(sim_data) + pointnet = PointNetEncoder(config) + out = pointnet.encode_points(sim_data) print("global feat", out.size()) - config = { - "in_dim": 3, - "out_dim": 1024, - "global_feat": False, - "feature_transform": False - } - - pointnet = PointNetEncoder(config) - out = pointnet.encode_points(sim_data) + out, per_point_out = pointnet.encode_points(sim_data, require_per_point_feat=True) print("point feat", out.size()) + print("per point feat", per_point_out.size())