import MinkowskiEngine as ME
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck

from models.resnet import ResNetBase


class MinkUNetBase(ResNetBase):
    """U-Net style network assembled from MinkowskiEngine layers.

    The encoder is a stem convolution followed by four stages, each a
    stride-2 convolution, batch norm, ReLU and a stack of ``BLOCK`` modules.
    The decoder mirrors it with four stride-2 transposed convolutions; after
    each one the result is concatenated with the matching encoder output
    before going through another ``BLOCK`` stack. A final 1x1 convolution
    maps to ``out_channels``.

    Concrete variants subclass this and set ``BLOCK``, and optionally
    ``LAYERS`` / ``PLANES``.
    """

    # Residual block class handed to ``_make_layer``; subclasses must set it.
    BLOCK = None
    DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
    # Per-stage values passed to ``_make_layer`` (4 encoder + 4 decoder stages).
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
    # Channel count produced by the stem convolution.
    INIT_DIM = 32
    OUT_TENSOR_STRIDE = 1

    # To use the model, must call initialize_coords before forward pass.
    # Once data is processed, call clear to reset the model before calling
    # initialize_coords.
    # NOTE(review): neither initialize_coords nor clear is defined in this
    # file -- presumably they live on ResNetBase; confirm the note still holds.
    def __init__(self, in_channels, out_channels, D=3):
        """Delegate construction to ``ResNetBase``.

        Args:
            in_channels: channel count of the input features.
            out_channels: channel count produced by the final convolution.
            D: spatial dimension forwarded to every Minkowski layer.
        """
        ResNetBase.__init__(self, in_channels, out_channels, D)

    def network_initialization(self, in_channels, out_channels, D):
        """Create every layer of the network.

        Layers are created in the same order they are applied in
        ``forward``. ``self.inplanes`` is read after each ``_make_layer``
        call as the current channel width.
        NOTE(review): ``_make_layer`` is inherited; presumably it updates
        ``self.inplanes`` to the block output width -- confirm in ResNetBase.
        """
        block, planes, layers = self.BLOCK, self.PLANES, self.LAYERS
        # Every down/up-sampling convolution shares these settings.
        stride2 = dict(kernel_size=2, stride=2, dimension=D)

        # ---- stem -------------------------------------------------------
        # The stem output is kept in forward() and concatenated back in
        # just before block8.
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=D)
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        # ---- encoder ----------------------------------------------------
        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, **stride2)
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block1 = self._make_layer(block, planes[0], layers[0])

        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, **stride2)
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block2 = self._make_layer(block, planes[1], layers[1])

        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, **stride2)
        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(block, planes[2], layers[2])

        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, **stride2)
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(block, planes[3], layers[3])

        # ---- decoder ----------------------------------------------------
        # After each transposed conv, forward() concatenates an encoder
        # block output, so the next stage's input width is the upsampled
        # width plus that skip's width (planes * block.expansion).
        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, planes[4], **stride2)
        self.bntr4 = ME.MinkowskiBatchNorm(planes[4])
        self.inplanes = planes[4] + planes[2] * block.expansion
        self.block5 = self._make_layer(block, planes[4], layers[4])

        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, planes[5], **stride2)
        self.bntr5 = ME.MinkowskiBatchNorm(planes[5])
        self.inplanes = planes[5] + planes[1] * block.expansion
        self.block6 = self._make_layer(block, planes[5], layers[5])

        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, planes[6], **stride2)
        self.bntr6 = ME.MinkowskiBatchNorm(planes[6])
        self.inplanes = planes[6] + planes[0] * block.expansion
        self.block7 = self._make_layer(block, planes[6], layers[6])

        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, planes[7], **stride2)
        self.bntr7 = ME.MinkowskiBatchNorm(planes[7])
        # The last skip is the stem output, which has INIT_DIM channels and
        # is not a block output, hence no expansion factor here.
        self.inplanes = planes[7] + self.INIT_DIM
        self.block8 = self._make_layer(block, planes[7], layers[7])

        # ---- head -------------------------------------------------------
        self.final = ME.MinkowskiConvolution(
            planes[7] * block.expansion,
            out_channels,
            kernel_size=1,
            bias=True,
            dimension=D)
        # One ReLU module is reused after every batch norm in forward().
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, x):
        """Run the encoder/decoder on ``x`` and return the final conv output.

        Args:
            x: input accepted by ``ME.MinkowskiConvolution``
                (presumably an ``ME.SparseTensor`` -- confirm with callers).
        """
        # Stem; kept for the last skip connection.
        skip_p1 = self.relu(self.bn0(self.conv0p1s1(x)))

        # Encoder: each stage output is kept for the matching decoder stage.
        out = self.relu(self.bn1(self.conv1p1s2(skip_p1)))
        skip_p2 = self.block1(out)

        out = self.relu(self.bn2(self.conv2p2s2(skip_p2)))
        skip_p4 = self.block2(out)

        out = self.relu(self.bn3(self.conv3p4s2(skip_p4)))
        skip_p8 = self.block3(out)

        # tensor_stride=16
        out = self.relu(self.bn4(self.conv4p8s2(skip_p8)))
        out = self.block4(out)

        # Decoder: upsample, concatenate the encoder output, refine.
        # tensor_stride=8
        out = self.relu(self.bntr4(self.convtr4p16s2(out)))
        out = self.block5(ME.cat(out, skip_p8))

        # tensor_stride=4
        out = self.relu(self.bntr5(self.convtr5p8s2(out)))
        out = self.block6(ME.cat(out, skip_p4))

        # tensor_stride=2
        out = self.relu(self.bntr6(self.convtr6p4s2(out)))
        out = self.block7(ME.cat(out, skip_p2))

        # tensor_stride=1
        out = self.relu(self.bntr7(self.convtr7p2s2(out)))
        out = self.block8(ME.cat(out, skip_p1))

        return self.final(out)


# ---- depth variants: choose the block type and the per-stage LAYERS -------

class MinkUNet14(MinkUNetBase):
    """BasicBlock variant with LAYERS of 1 for every stage."""
    BLOCK = BasicBlock
    LAYERS = (1, 1, 1, 1, 1, 1, 1, 1)


class MinkUNet18(MinkUNetBase):
    """BasicBlock variant with LAYERS of 2 for every stage."""
    BLOCK = BasicBlock
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)


class MinkUNet34(MinkUNetBase):
    """BasicBlock variant with encoder LAYERS of (2, 3, 4, 6)."""
    BLOCK = BasicBlock
    LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)


class MinkUNet50(MinkUNetBase):
    """Bottleneck variant with encoder LAYERS of (2, 3, 4, 6)."""
    BLOCK = Bottleneck
    LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)


class MinkUNet101(MinkUNetBase):
    """Bottleneck variant with encoder LAYERS of (2, 3, 4, 23)."""
    BLOCK = Bottleneck
    LAYERS = (2, 3, 4, 23, 2, 2, 2, 2)


# ---- width variants: same depth as the parent, different PLANES -----------

class MinkUNet14A(MinkUNet14):
    """MinkUNet14 with decoder PLANES of (128, 128, 96, 96)."""
    PLANES = (32, 64, 128, 256, 128, 128, 96, 96)


class MinkUNet14B(MinkUNet14):
    """MinkUNet14 with decoder PLANES of (128, 128, 128, 128)."""
    PLANES = (32, 64, 128, 256, 128, 128, 128, 128)


class MinkUNet14C(MinkUNet14):
    """MinkUNet14 with decoder PLANES of (192, 192, 128, 128)."""
    PLANES = (32, 64, 128, 256, 192, 192, 128, 128)


class MinkUNet14Dori(MinkUNet14):
    """MinkUNet14 with decoder PLANES of 384 throughout."""
    # Same PLANES as MinkUNet14E.
    PLANES = (32, 64, 128, 256, 384, 384, 384, 384)


class MinkUNet14E(MinkUNet14):
    """MinkUNet14 with decoder PLANES of 384 throughout."""
    # Same PLANES as MinkUNet14Dori.
    PLANES = (32, 64, 128, 256, 384, 384, 384, 384)


class MinkUNet14D(MinkUNet14):
    """MinkUNet14 with decoder PLANES of 192 throughout."""
    PLANES = (32, 64, 128, 256, 192, 192, 192, 192)


class MinkUNet18A(MinkUNet18):
    """MinkUNet18 with decoder PLANES of (128, 128, 96, 96)."""
    PLANES = (32, 64, 128, 256, 128, 128, 96, 96)


class MinkUNet18B(MinkUNet18):
    """MinkUNet18 with decoder PLANES of (128, 128, 128, 128)."""
    PLANES = (32, 64, 128, 256, 128, 128, 128, 128)


class MinkUNet18D(MinkUNet18):
    """MinkUNet18 with decoder PLANES of 384 throughout."""
    PLANES = (32, 64, 128, 256, 384, 384, 384, 384)


class MinkUNet34A(MinkUNet34):
    """MinkUNet34 with decoder PLANES of (256, 128, 64, 64)."""
    PLANES = (32, 64, 128, 256, 256, 128, 64, 64)


class MinkUNet34B(MinkUNet34):
    """MinkUNet34 with decoder PLANES of (256, 128, 64, 32)."""
    PLANES = (32, 64, 128, 256, 256, 128, 64, 32)


class MinkUNet34C(MinkUNet34):
    """MinkUNet34 with decoder PLANES of (256, 128, 96, 96)."""
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)