add per_points_encoder

This commit is contained in:
hofee 2024-09-29 20:12:44 +08:00
parent 2753f114a3
commit f42e45d608

View File

@ -22,12 +22,10 @@ class PointNetEncoder(nn.Module):
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 512, 1)
self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1)
self.global_feat = config["global_feat"]
if self.feature_transform:
self.f_stn = STNkd(k=64)
def forward(self, x):
n_pts = x.shape[2]
trans = self.stn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans)
@ -46,20 +44,15 @@ class PointNetEncoder(nn.Module):
x = self.conv4(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, self.out_dim)
if self.global_feat:
return x
else:
x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts)
return torch.cat([x, point_feat], 1)
return x, point_feat
def encode_points(self, pts):
def encode_points(self, pts, require_per_point_feat=False):
pts = pts.transpose(2, 1)
if not self.global_feat:
pts_feature = self(pts).transpose(2, 1)
global_pts_feature, per_point_feature = self(pts)
if require_per_point_feat:
return global_pts_feature, per_point_feature.transpose(2, 1)
else:
pts_feature = self(pts)
return pts_feature
return global_pts_feature
class STNkd(nn.Module):
def __init__(self, k=64):
@ -102,21 +95,13 @@ if __name__ == "__main__":
config = {
"in_dim": 3,
"out_dim": 1024,
"global_feat": True,
"feature_transform": False
}
pointnet_global = PointNetEncoder(config)
out = pointnet_global.encode_points(sim_data)
pointnet = PointNetEncoder(config)
out = pointnet.encode_points(sim_data)
print("global feat", out.size())
config = {
"in_dim": 3,
"out_dim": 1024,
"global_feat": False,
"feature_transform": False
}
pointnet = PointNetEncoder(config)
out = pointnet.encode_points(sim_data)
out, per_point_out = pointnet.encode_points(sim_data, require_per_point_feat=True)
print("point feat", out.size())
print("per point feat", per_point_out.size())