299 lines
7.3 KiB
Python
Raw Normal View History

2024-10-09 16:13:22 +00:00
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch '''
import torch
import torch.nn as nn
from typing import List, Tuple
class SharedMLP(nn.Sequential):
    """Chain of ``Conv2d`` units built from a list of channel sizes.

    Consecutive entries of ``args`` are the input and output channel counts of
    one unit, so ``len(args) - 1`` units are registered under the names
    ``<name>layer0``, ``<name>layer1``, ...

    Args:
        args: channel sizes; unit ``i`` maps ``args[i]`` to ``args[i + 1]``.
        bn: request batch norm in the units (see ``first`` for the exception).
        activation: activation module handed to the units; the same instance
            is passed to every unit.
        preact: forwarded unchanged to every unit.
        first: when set together with ``preact``, unit 0 is built without
            batch norm and without activation.
        name: prefix for the registered unit names.
    """

    def __init__(
        self,
        args: List[int],
        *,
        bn: bool = False,
        activation=nn.ReLU(inplace=True),
        preact: bool = False,
        first: bool = False,
        name: str = ""
    ):
        super().__init__()
        channel_pairs = zip(args, args[1:])
        for idx, (c_in, c_out) in enumerate(channel_pairs):
            # Only the very first unit of a "first" pre-activation stack is
            # stripped of its batch norm and activation.
            bare = first and preact and idx == 0
            unit = Conv2d(
                c_in,
                c_out,
                bn=False if bare else bn,
                activation=None if bare else activation,
                preact=preact
            )
            self.add_module(name + 'layer' + str(idx), unit)
class _BNBase(nn.Sequential):
def __init__(self, in_size, batch_norm=None, name=""):
super().__init__()
self.add_module(name + "bn", batch_norm(in_size))
nn.init.constant_(self[0].weight, 1.0)
nn.init.constant_(self[0].bias, 0)
class BatchNorm1d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm1d``.

    Args:
        in_size: size passed on to ``nn.BatchNorm1d``.
        name: keyword-only prefix for the registered module name.
    """

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, nn.BatchNorm1d, name)
class BatchNorm2d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm2d``.

    Args:
        in_size: size passed on to ``nn.BatchNorm2d``.
        name: prefix for the registered module name.
    """

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, nn.BatchNorm2d, name)
class BatchNorm3d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm3d``.

    Args:
        in_size: size passed on to ``nn.BatchNorm3d``.
        name: prefix for the registered module name.
    """

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, nn.BatchNorm3d, name)
class _ConvBase(nn.Sequential):
def __init__(
self,
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=None,
batch_norm=None,
bias=True,
preact=False,
name=""
):
super().__init__()
bias = bias and (not bn)
conv_unit = conv(
in_size,
out_size,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias
)
init(conv_unit.weight)
if bias:
nn.init.constant_(conv_unit.bias, 0)
if bn:
if not preact:
bn_unit = batch_norm(out_size)
else:
bn_unit = batch_norm(in_size)
if preact:
if bn:
self.add_module(name + 'bn', bn_unit)
if activation is not None:
self.add_module(name + 'activation', activation)
self.add_module(name + 'conv', conv_unit)
if not preact:
if bn:
self.add_module(name + 'bn', bn_unit)
if activation is not None:
self.add_module(name + 'activation', activation)
class Conv1d(_ConvBase):
    """``_ConvBase`` wired to ``nn.Conv1d`` and the local ``BatchNorm1d``.

    Defaults give a kernel of size 1, stride 1, no padding, a ReLU
    activation, no batch norm and Kaiming-normal weight initialisation.
    Every argument is passed through to ``_ConvBase`` unchanged.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = ""
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
        )
class Conv2d(_ConvBase):
    """``_ConvBase`` wired to ``nn.Conv2d`` and the local ``BatchNorm2d``.

    Defaults give a (1, 1) kernel, (1, 1) stride, no padding, a ReLU
    activation, no batch norm and Kaiming-normal weight initialisation.
    Every argument is passed through to ``_ConvBase`` unchanged.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = ""
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
        )
class Conv3d(_ConvBase):
    """``_ConvBase`` wired to ``nn.Conv3d`` and the local ``BatchNorm3d``.

    Defaults give a (1, 1, 1) kernel, (1, 1, 1) stride, no padding, a ReLU
    activation, no batch norm and Kaiming-normal weight initialisation.
    Every argument is passed through to ``_ConvBase`` unchanged.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: Tuple[int, int, int] = (1, 1, 1),
        stride: Tuple[int, int, int] = (1, 1, 1),
        padding: Tuple[int, int, int] = (0, 0, 0),
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = ""
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv3d,
            batch_norm=BatchNorm3d,
            bias=bias,
            preact=preact,
            name=name,
        )
class FC(nn.Sequential):
    """Sequential unit made of a linear layer plus optional batch norm/activation.

    Registered module order is ``fc -> bn -> activation`` by default and
    ``bn -> activation -> fc`` when ``preact`` is true.

    Args:
        in_size: input feature count of the linear layer.
        out_size: output feature count of the linear layer.
        activation: module registered as the activation, or ``None`` for none.
        bn: add a ``BatchNorm1d`` module; the linear layer then has no bias.
        init: optional callable applied to the linear weight; ``None`` keeps
            the layer's own initialisation.
        preact: put batch norm and activation before the linear layer.
        name: prefix for all registered module names.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=None,
        preact: bool = False,
        name: str = ""
    ):
        super().__init__()

        with_bias = not bn
        linear = nn.Linear(in_size, out_size, bias=with_bias)
        if init is not None:
            init(linear.weight)
        if with_bias:
            nn.init.constant_(linear.bias, 0)

        # Batch norm and activation keep their relative order; only their
        # position around the linear layer depends on preact.
        extras = []
        if bn:
            # In pre-activation mode the norm sees the layer's input features.
            norm_features = in_size if preact else out_size
            extras.append((name + 'bn', BatchNorm1d(norm_features)))
        if activation is not None:
            extras.append((name + 'activation', activation))

        fc_entry = (name + 'fc', linear)
        layout = extras + [fc_entry] if preact else [fc_entry] + extras
        for key, module in layout:
            self.add_module(key, module)
def set_bn_momentum_default(bn_momentum):
    """Build a per-module callback that assigns ``bn_momentum``.

    The returned function takes a module and, when it is an instance of
    ``nn.BatchNorm1d``, ``nn.BatchNorm2d`` or ``nn.BatchNorm3d``, sets its
    ``momentum`` attribute to ``bn_momentum``. Other modules are left alone.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    def fn(m):
        if not isinstance(m, bn_types):
            return
        m.momentum = bn_momentum

    return fn
class BNMomentumScheduler:
    """Epoch-driven scheduler that pushes a momentum value into a model.

    On every step the momentum ``bn_lambda(epoch)`` is computed, turned into a
    per-module callback with ``setter`` and applied to ``model`` through
    ``model.apply``.

    Args:
        model: the ``nn.Module`` to update.
        bn_lambda: callable mapping an epoch number to a momentum value.
        last_epoch: epoch counter to start from.
        setter: callable turning a momentum value into a per-module callback.

    Raises:
        RuntimeError: if ``model`` is not an ``nn.Module``.
    """

    def __init__(
        self, model, bn_lambda, last_epoch=-1,
        setter=set_bn_momentum_default
    ):
        if not isinstance(model, nn.Module):
            model_cls = type(model).__name__
            raise RuntimeError(
                "Class '{}' is not a PyTorch nn Module".format(model_cls)
            )

        self.model = model
        self.setter = setter
        self.lmbd = bn_lambda

        # Apply the momentum for epoch ``last_epoch + 1`` right away, then
        # rewind the counter so the next argument-less step() targets that
        # same epoch again.
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def step(self, epoch=None):
        """Set the momentum for ``epoch`` (default: one past the last epoch)."""
        target = self.last_epoch + 1 if epoch is None else epoch
        self.last_epoch = target
        momentum = self.lmbd(target)
        self.model.apply(self.setter(momentum))