# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' import torch import torch.nn as nn from typing import List, Tuple class SharedMLP(nn.Sequential): def __init__( self, args: List[int], *, bn: bool = False, activation=nn.ReLU(inplace=True), preact: bool = False, first: bool = False, name: str = "" ): super().__init__() for i in range(len(args) - 1): self.add_module( name + 'layer{}'.format(i), Conv2d( args[i], args[i + 1], bn=(not first or not preact or (i != 0)) and bn, activation=activation if (not first or not preact or (i != 0)) else None, preact=preact ) ) class _BNBase(nn.Sequential): def __init__(self, in_size, batch_norm=None, name=""): super().__init__() self.add_module(name + "bn", batch_norm(in_size)) nn.init.constant_(self[0].weight, 1.0) nn.init.constant_(self[0].bias, 0) class BatchNorm1d(_BNBase): def __init__(self, in_size: int, *, name: str = ""): super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) class BatchNorm2d(_BNBase): def __init__(self, in_size: int, name: str = ""): super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) class BatchNorm3d(_BNBase): def __init__(self, in_size: int, name: str = ""): super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) class _ConvBase(nn.Sequential): def __init__( self, in_size, out_size, kernel_size, stride, padding, activation, bn, init, conv=None, batch_norm=None, bias=True, preact=False, name="" ): super().__init__() bias = bias and (not bn) conv_unit = conv( in_size, out_size, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) init(conv_unit.weight) if bias: nn.init.constant_(conv_unit.bias, 0) if bn: if not preact: bn_unit = batch_norm(out_size) else: bn_unit = batch_norm(in_size) if preact: if bn: self.add_module(name + 'bn', bn_unit) if activation is not None: self.add_module(name + 'activation', activation) self.add_module(name + 'conv', conv_unit) if not preact: if bn: self.add_module(name + 'bn', bn_unit) if activation is not None: self.add_module(name + 'activation', activation) class Conv1d(_ConvBase): def __init__( self, in_size: int, out_size: int, *, kernel_size: int = 1, stride: int = 1, padding: int = 0, activation=nn.ReLU(inplace=True), bn: bool = False, init=nn.init.kaiming_normal_, bias: bool = True, preact: bool = False, name: str = "" ): super().__init__( in_size, out_size, kernel_size, stride, padding, activation, bn, init, conv=nn.Conv1d, batch_norm=BatchNorm1d, bias=bias, preact=preact, name=name ) class Conv2d(_ConvBase): def __init__( self, in_size: int, out_size: int, *, kernel_size: Tuple[int, int] = (1, 1), stride: Tuple[int, int] = (1, 1), padding: Tuple[int, int] = (0, 0), activation=nn.ReLU(inplace=True), bn: bool = False, init=nn.init.kaiming_normal_, bias: bool = True, preact: bool = False, name: str = "" ): super().__init__( in_size, out_size, kernel_size, stride, padding, activation, bn, init, conv=nn.Conv2d, batch_norm=BatchNorm2d, bias=bias, preact=preact, name=name ) class Conv3d(_ConvBase): def __init__( self, in_size: int, out_size: int, *, kernel_size: Tuple[int, int, int] = (1, 1, 1), stride: Tuple[int, int, int] = (1, 1, 1), padding: Tuple[int, int, int] = (0, 0, 0), activation=nn.ReLU(inplace=True), bn: bool = False, init=nn.init.kaiming_normal_, bias: bool = True, preact: bool = False, name: str = "" ): super().__init__( in_size, out_size, kernel_size, stride, padding, activation, bn, init, conv=nn.Conv3d, batch_norm=BatchNorm3d, bias=bias, preact=preact, name=name ) class FC(nn.Sequential): def __init__( self, in_size: int, out_size: int, *, activation=nn.ReLU(inplace=True), bn: bool = False, init=None, preact: bool = False, name: str = "" ): super().__init__() fc = nn.Linear(in_size, out_size, bias=not bn) if init is not None: init(fc.weight) if not bn: nn.init.constant_(fc.bias, 0) if preact: if bn: self.add_module(name + 'bn', BatchNorm1d(in_size)) if activation is not None: self.add_module(name + 'activation', activation) self.add_module(name + 'fc', fc) if not preact: if bn: self.add_module(name + 'bn', BatchNorm1d(out_size)) if activation is not None: self.add_module(name + 'activation', activation) def set_bn_momentum_default(bn_momentum): def fn(m): if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): m.momentum = bn_momentum return fn class BNMomentumScheduler(object): def __init__( self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default ): if not isinstance(model, nn.Module): raise RuntimeError( "Class '{}' is not a PyTorch nn Module".format( type(model).__name__ ) ) self.model = model self.setter = setter self.lmbd = bn_lambda self.step(last_epoch + 1) self.last_epoch = last_epoch def step(self, epoch=None): if epoch is None: epoch = self.last_epoch + 1 self.last_epoch = epoch self.model.apply(self.setter(self.lmbd(epoch)))