42 lines
1.1 KiB
Python
Raw Normal View History

2024-10-09 16:13:22 +00:00
import torch.nn as nn
import torch
import torch.nn.functional as F
class RotHead(nn.Module):
def __init__(self, in_feat_dim, out_dim=3):
super(RotHead, self).__init__()
self.f = in_feat_dim
self.k = out_dim
self.conv1 = torch.nn.Conv1d(self.f, 1024, 1)
self.conv2 = torch.nn.Conv1d(1024, 256, 1)
self.conv3 = torch.nn.Conv1d(256, 256, 1)
self.conv4 = torch.nn.Conv1d(256, self.k, 1)
self.drop1 = nn.Dropout(0.2)
self.bn1 = nn.BatchNorm1d(1024)
self.bn2 = nn.BatchNorm1d(256)
self.bn3 = nn.BatchNorm1d(256)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = torch.max(x, 2, keepdim=True)[0]
x = F.relu(self.bn3(self.conv3(x)))
x = self.drop1(x)
x = self.conv4(x)
x = x.squeeze(2)
x = x.contiguous()
return x
if __name__ == "__main__":
points = torch.rand(2, 1350, 1024) # batch_size x feature x num_of_point
rot_head = RotHead(in_feat_dim=1350, out_dim=3)
rot = rot_head(points)
print(rot.shape)