"""Weight Standardized convolution layers, see https://arxiv.org/abs/1903.10520"""
# Adapted from https://github.com/vballoli/nfnets-pytorch/blob/61f0d6387/nfnets/base.py
# Added 3D versions of the WS layers
import torch
from torch import nn
from torch.functional import F
from torch import Tensor
from typing import Optional, List, Tuple
[docs]
class FWS(nn.Module):
    """Wrap a layer and re-standardize its ``weight`` in place before each call.

    Similar in spirit to weight standardization, except that the standardized
    values are written back into ``layer.weight`` (inside ``torch.no_grad()``)
    rather than being produced as a temporary tensor for the forward pass.

    Args:
        layer: Module exposing a ``weight`` tensor. It is invoked with whatever
            arguments are passed to :meth:`forward`.
        learnable_gain: If ``True``, ``gain`` is an ``nn.Parameter``; otherwise
            it is registered as a buffer. Either way it starts as ones, with
            one entry per index along dimension 0 of the weight.
        const_eval: If ``True``, the weight is left untouched while the module
            is in eval mode.
        eps: Lower bound applied to ``var * fan_in`` before the reciprocal
            square root is taken.
    """

    def __init__(
        self,
        layer: nn.Module,
        learnable_gain: bool = True,
        const_eval: bool = False,
        eps: float = 1e-4
    ):
        super().__init__()
        self.layer = layer
        self.const_eval = const_eval

        weight_shape = self.layer.weight.shape
        # Statistics are reduced over every weight dimension except the first.
        self.vmdims: Tuple[int, ...] = tuple(range(1, len(weight_shape)))

        if not learnable_gain:
            self.register_buffer('gain', torch.ones(weight_shape[0]))
        else:
            self.gain = nn.Parameter(
                torch.ones(weight_shape[0], requires_grad=True)
            )

        # NOTE: ``fan_in`` is the product of *all* weight dimensions.
        self.register_buffer('fan_in', torch.prod(torch.tensor(weight_shape)))
        self.register_buffer('eps', torch.tensor(eps))

    def standardize_weight(self):
        """Overwrite ``self.layer.weight`` with its standardized version.

        Computes ``weight * scale - mean * scale`` where
        ``scale = rsqrt(max(var * fan_in, eps)) * gain`` and stores the result
        in place. Returns ``None``.
        """
        weight = self.layer.weight
        variance, average = torch.var_mean(weight, dim=self.vmdims, keepdim=True)
        floored = torch.max(variance * self.fan_in, self.eps)
        multiplier = torch.rsqrt(floored) * self.gain.view_as(variance)
        offset = average * multiplier
        # The in-place update is hidden from autograd.
        with torch.no_grad():
            weight.mul_(multiplier)
            weight.sub_(offset)

    def forward(self, *args, **kwargs):
        """Standardize the weight (unless frozen for eval) and call the layer."""
        frozen = self.const_eval and not self.training
        if not frozen:
            self.standardize_weight()
        return self.layer(*args, **kwargs)
[docs]
class WSConv3d(nn.Conv3d):
    """3D convolution whose kernel is standardized on every forward pass.

    The stored ``weight`` is left untouched; :meth:`forward` convolves with
    ``weight * scale - mean * scale`` where ``mean``/``var`` are taken over all
    weight dimensions except the first and
    ``scale = rsqrt(max(var * fan_in, eps)) * gain``.

    Constructor arguments are those of :class:`torch.nn.Conv3d`. The weight is
    re-initialized with ``nn.init.kaiming_normal_`` and a learnable ``gain``
    (ones, one entry per index along weight dimension 0) is added.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(torch.ones(self.weight.shape[0], requires_grad=True))

    def standardize_weight(self, eps):
        """Return the standardized kernel (same shape as ``self.weight``).

        Args:
            eps: Lower bound applied to ``var * fan_in`` before ``rsqrt``.
        """
        var, mean = torch.var_mean(self.weight, dim=(1, 2, 3, 4), keepdim=True)
        # NOTE(review): fan_in is the product of ALL weight dims (dim 0
        # included). This looks unintended, but it is kept because changing it
        # would alter the numerics of existing checkpoints - TODO confirm.
        fan_in = self.weight.numel()
        floor = torch.tensor(eps, device=var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) * self.gain.view_as(var)
        shift = mean * scale
        return self.weight * scale - shift

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the standardized kernel.

        Args:
            input: Tensor accepted by :class:`torch.nn.Conv3d`.
            eps: Forwarded to :meth:`standardize_weight`.
        """
        weight = self.standardize_weight(eps)
        # Delegate to the parent implementation so that a non-'zeros'
        # ``padding_mode`` is honored instead of silently falling back to
        # zero padding.
        return self._conv_forward(input, weight, self.bias)
[docs]
class WSConvTranspose3d(nn.ConvTranspose3d):
    """3D transposed convolution with a kernel standardized on every forward.

    The stored ``weight`` is left untouched; :meth:`forward` uses
    ``weight * scale - mean * scale`` where ``mean``/``var`` are taken over all
    weight dimensions except the first and
    ``scale = rsqrt(max(var * fan_in, eps)) * gain``.

    Constructor arguments are those of :class:`torch.nn.ConvTranspose3d`. The
    weight is re-initialized with ``nn.init.kaiming_normal_`` and a learnable
    ``gain`` (ones, one entry per index along weight dimension 0) is added.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 groups: int = 1,
                 bias: bool = True,
                 dilation: int = 1,
                 padding_mode: str = 'zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         output_padding=output_padding, groups=groups, bias=bias, dilation=dilation, padding_mode=padding_mode)
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(torch.ones(self.weight.size(0), requires_grad=True))

    def standardize_weight(self, eps):
        """Return the standardized kernel (same shape as ``self.weight``).

        Args:
            eps: Lower bound applied to ``var * fan_in`` before ``rsqrt``.
        """
        var, mean = torch.var_mean(self.weight, dim=(1, 2, 3, 4), keepdim=True)
        # NOTE(review): fan_in is the product of ALL weight dims (dim 0
        # included). This looks unintended, but it is kept because changing it
        # would alter the numerics of existing checkpoints - TODO confirm.
        fan_in = self.weight.numel()
        floor = torch.tensor(eps, device=var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) * self.gain.view_as(var)
        shift = mean * scale
        return self.weight * scale - shift

    def _resolve_output_padding(self, input, output_size):
        """Translate a requested ``output_size`` into per-dimension output padding."""
        if output_size is None:
            return self.output_padding
        num_spatial = 3
        sizes = list(output_size)
        if len(sizes) == input.dim():
            # A full size (including non-spatial dims) was given.
            sizes = sizes[-num_spatial:]
        if len(sizes) != num_spatial:
            raise ValueError(
                "output_size must have {} or {} elements (got {})".format(
                    num_spatial, input.dim(), len(output_size)))
        extra = []
        for d, requested in enumerate(sizes):
            smallest = ((input.size(d - num_spatial) - 1) * self.stride[d]
                        - 2 * self.padding[d]
                        + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            largest = smallest + self.stride[d] - 1
            if not smallest <= requested <= largest:
                raise ValueError(
                    "requested output size {} is invalid for spatial dim {}: "
                    "valid range is [{}, {}]".format(requested, d, smallest, largest))
            extra.append(requested - smallest)
        return tuple(extra)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None, eps: float = 1e-4) -> Tensor:
        """Apply the transposed 3D convolution with the standardized kernel.

        Args:
            input: Tensor accepted by :class:`torch.nn.ConvTranspose3d`.
            output_size: Optional desired output size (spatial dims only, or
                the full size); used to derive the output padding.
            eps: Forwarded to :meth:`standardize_weight`.

        Raises:
            ValueError: If ``output_size`` has the wrong length or is not
                reachable with the configured stride/padding/dilation.
        """
        weight = self.standardize_weight(eps)
        output_padding = self._resolve_output_padding(input, output_size)
        # Must be the 3D functional: the 2D variant rejects 5-D inputs.
        return F.conv_transpose3d(input, weight, self.bias, self.stride, self.padding,
                                  output_padding, self.groups, self.dilation)
[docs]
class WSConv1d(nn.Conv1d):
    r"""Applies a 1D convolution over an input signal composed of several input
    planes after weight normalization/standardization.

    The stored ``weight`` is left untouched; on every forward pass the kernel
    actually used is ``weight * scale - mean * scale`` where ``mean``/``var``
    are taken over all weight dimensions except the first and
    ``scale = rsqrt(max(var * fan_in, eps)) * gain``.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
    precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
        \star \text{input}(N_i, k)

    where :math:`\star` is the valid `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`L` is a length of signal sequence.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a one-element tuple.

    * :attr:`padding` controls the amount of implicit zero-paddings on both sides
      for :attr:`padding` number of points.

    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.

    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels,
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters,
          of size
          :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.

    Note:

        Depending of the size of your kernel, several (of the last)
        columns of the input might be lost, because it is a valid
        `cross-correlation`_, and not a full `cross-correlation`_.
        It is up to the user to add proper padding.

    Note:

        When `groups == in_channels` and `out_channels == K * in_channels`,
        where `K` is a positive integer, this operation is also termed in
        literature as depthwise convolution.

        In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
        a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
        :math:`(C_\text{in}=C_{in}, C_\text{out}=C_{in} \times K, ..., \text{groups}=C_{in})`.

    Note:
        In some circumstances when using the CUDA backend with CuDNN, this operator
        may select a nondeterministic algorithm to increase performance. If this is
        undesirable, you can try to make the operation deterministic (potentially at
        a performance cost) by setting ``torch.backends.cudnn.deterministic =
        True``.
        Please see the notes on :doc:`/notes/randomness` for background.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` where

          .. math::
              L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                        \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels},
            \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
            The values are re-initialized with ``nn.init.kaiming_normal_``.
        bias (Tensor): the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``, then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
        gain (Tensor): learnable multiplier applied to the standardized
            kernel, one entry per index along weight dimension 0; initialized
            to ones.

    Examples::

        >>> m = WSConv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(torch.ones(
            self.weight.size()[0], requires_grad=True))

    def standardize_weight(self, eps):
        """Return the standardized kernel (same shape as ``self.weight``).

        Args:
            eps: Lower bound applied to ``var * fan_in`` before ``rsqrt``.
        """
        var, mean = torch.var_mean(self.weight, dim=(1, 2), keepdim=True)
        # NOTE(review): fan_in is the product of ALL weight dims (dim 0
        # included). This looks unintended, but it is kept because changing it
        # would alter the numerics of existing checkpoints - TODO confirm.
        fan_in = self.weight.numel()
        floor = torch.tensor(eps, device=var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) * self.gain.view_as(var)
        shift = mean * scale
        return self.weight * scale - shift

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the standardized kernel.

        Args:
            input: Tensor accepted by :class:`torch.nn.Conv1d`.
            eps: Forwarded to :meth:`standardize_weight`.
        """
        weight = self.standardize_weight(eps)
        # Delegate to the parent implementation so that a non-'zeros'
        # ``padding_mode`` is honored instead of silently falling back to
        # zero padding.
        return self._conv_forward(input, weight, self.bias)
[docs]
class WSConv2d(nn.Conv2d):
    r"""Applies a 2D convolution over an input signal composed of several input
    planes after weight normalization/standardization.

    Reference: https://github.com/deepmind/deepmind-research/blob/master/nfnets/base.py#L121

    The stored ``weight`` is left untouched; on every forward pass the kernel
    actually used is ``weight * scale - mean * scale`` where ``mean``/``var``
    are taken over all weight dimensions except the first and
    ``scale = rsqrt(max(var * fan_in, eps)) * gain``.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of implicit zero-paddings on both
      sides for :attr:`padding` number of points for each dimension.

    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.

    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels,
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters, of size:
          :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Note:

         Depending of the size of your kernel, several (of the last)
         columns of the input might be lost, because it is a valid `cross-correlation`_,
         and not a full `cross-correlation`_.
         It is up to the user to add proper padding.

    Note:

        When `groups == in_channels` and `out_channels == K * in_channels`,
        where `K` is a positive integer, this operation is also termed in
        literature as depthwise convolution.

        In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
        a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
        :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    Note:
        In some circumstances when using the CUDA backend with CuDNN, this operator
        may select a nondeterministic algorithm to increase performance. If this is
        undesirable, you can try to make the operation deterministic (potentially at
        a performance cost) by setting ``torch.backends.cudnn.deterministic =
        True``.
        Please see the notes on :doc:`/notes/randomness` for background.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values are re-initialized with ``nn.init.kaiming_normal_``.
        bias (Tensor):   the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        gain (Tensor): learnable multiplier applied to the standardized
            kernel, one entry per index along weight dimension 0; initialized
            to ones.

    Examples::

        >>> # With square kernels and equal stride
        >>> m = WSConv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = WSConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = WSConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(torch.ones(
            self.weight.size(0), requires_grad=True))

    def standardize_weight(self, eps):
        """Return the standardized kernel (same shape as ``self.weight``).

        Args:
            eps: Lower bound applied to ``var * fan_in`` before ``rsqrt``.
        """
        var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdim=True)
        # NOTE(review): fan_in is the product of ALL weight dims (dim 0
        # included). This looks unintended, but it is kept because changing it
        # would alter the numerics of existing checkpoints - TODO confirm.
        fan_in = self.weight.numel()
        floor = torch.tensor(eps, device=var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) * self.gain.view_as(var)
        shift = mean * scale
        return self.weight * scale - shift

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the standardized kernel.

        Args:
            input: Tensor accepted by :class:`torch.nn.Conv2d`.
            eps: Forwarded to :meth:`standardize_weight`.
        """
        weight = self.standardize_weight(eps)
        # Delegate to the parent implementation so that a non-'zeros'
        # ``padding_mode`` is honored instead of silently falling back to
        # zero padding.
        return self._conv_forward(input, weight, self.bias)
[docs]
class WSConvTranspose2d(nn.ConvTranspose2d):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes after weight normalization/standardization.

    The stored ``weight`` is left untouched; on every forward pass the kernel
    actually used is ``weight * scale - mean * scale`` where ``mean``/``var``
    are taken over all weight dimensions except the first and
    ``scale = rsqrt(max(var * fan_in, eps)) * gain``.

    This module can be seen as the gradient of Conv2d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation).

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero-paddings on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.

    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels,
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters (of size
          :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`).

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimensions
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    .. note::

         Depending of the size of your kernel, several (of the last)
         columns of the input might be lost, because it is a valid `cross-correlation`_,
         and not a full `cross-correlation`_.
         It is up to the user to add proper padding.

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sizes of the input. This is set so that
        when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d`
        are initialized with same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        In some circumstances when using the CUDA backend with CuDNN, this operator
        may select a nondeterministic algorithm to increase performance. If this is
        undesirable, you can try to make the operation deterministic (potentially at
        a performance cost) by setting ``torch.backends.cudnn.deterministic =
        True``.
        Please see the notes on :doc:`/notes/randomness` for background.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
                         The values are re-initialized with
                         ``nn.init.kaiming_normal_``.
        bias (Tensor):   the learnable bias of the module of shape (out_channels)
                         If :attr:`bias` is ``True``, then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        gain (Tensor):   learnable multiplier applied to the standardized
                         kernel, one entry per index along weight dimension 0;
                         initialized to ones.

    Examples::

        >>> # With square kernels and equal stride
        >>> m = WSConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = WSConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> downsample = WSConv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = WSConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 groups: int = 1,
                 bias: bool = True,
                 dilation: int = 1,
                 padding_mode: str = 'zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         output_padding=output_padding, groups=groups, bias=bias, dilation=dilation, padding_mode=padding_mode)
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(torch.ones(
            self.weight.size(0), requires_grad=True))

    def standardize_weight(self, eps):
        """Return the standardized kernel (same shape as ``self.weight``).

        Args:
            eps: Lower bound applied to ``var * fan_in`` before ``rsqrt``.
        """
        var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdim=True)
        # NOTE(review): fan_in is the product of ALL weight dims (dim 0
        # included). This looks unintended, but it is kept because changing it
        # would alter the numerics of existing checkpoints - TODO confirm.
        fan_in = self.weight.numel()
        floor = torch.tensor(eps, device=var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) * self.gain.view_as(var)
        shift = mean * scale
        return self.weight * scale - shift

    def _resolve_output_padding(self, input, output_size):
        """Translate a requested ``output_size`` into per-dimension output padding."""
        if output_size is None:
            return self.output_padding
        num_spatial = 2
        sizes = list(output_size)
        if len(sizes) == input.dim():
            # A full size (including non-spatial dims) was given.
            sizes = sizes[-num_spatial:]
        if len(sizes) != num_spatial:
            raise ValueError(
                "output_size must have {} or {} elements (got {})".format(
                    num_spatial, input.dim(), len(output_size)))
        extra = []
        for d, requested in enumerate(sizes):
            smallest = ((input.size(d - num_spatial) - 1) * self.stride[d]
                        - 2 * self.padding[d]
                        + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            largest = smallest + self.stride[d] - 1
            if not smallest <= requested <= largest:
                raise ValueError(
                    "requested output size {} is invalid for spatial dim {}: "
                    "valid range is [{}, {}]".format(requested, d, smallest, largest))
            extra.append(requested - smallest)
        return tuple(extra)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None, eps: float = 1e-4) -> Tensor:
        """Apply the transposed 2D convolution with the standardized kernel.

        Args:
            input: Tensor accepted by :class:`torch.nn.ConvTranspose2d`.
            output_size: Optional desired output size (spatial dims only, or
                the full size); used to derive the output padding.
            eps: Forwarded to :meth:`standardize_weight`.

        Raises:
            ValueError: If ``output_size`` has the wrong length or is not
                reachable with the configured stride/padding/dilation.
        """
        weight = self.standardize_weight(eps)
        output_padding = self._resolve_output_padding(input, output_size)
        # Use the standardized kernel, not the raw ``self.weight``; otherwise
        # the standardization and ``gain`` have no effect on the output.
        return F.conv_transpose2d(input, weight, self.bias, self.stride, self.padding,
                                  output_padding, self.groups, self.dilation)