Source code for elektronn3.modules.wsconv

"""Weight Standardized convolution layers, see https://arxiv.org/abs/1903.10520"""

# Adapted from https://github.com/vballoli/nfnets-pytorch/blob/61f0d6387/nfnets/base.py
# Added 3D versions of the WS layers

import torch
from torch import nn
from torch.functional import F
from torch import Tensor

from typing import Optional, List, Tuple


class FWS(nn.Module):
    """Wrap a layer and standardize its weight tensor in place before each call.

    Kind of like weight standardization, but instead of building a
    standardized copy of the weight, ``layer.weight`` itself is overwritten.

    Args:
        layer: Module that exposes a ``weight`` tensor. It is called with
            whatever positional and keyword arguments :meth:`forward` receives.
        learnable_gain: If ``True``, ``gain`` is registered as an
            ``nn.Parameter``; otherwise it is registered as a buffer. In both
            cases it starts as a vector of ones with one entry per slice
            along the first weight dimension.
        const_eval: If ``True``, the weight is left untouched while the
            module is in eval mode.
        eps: Lower bound applied to ``var * fan_in`` before the reciprocal
            square root is taken.
    """

    def __init__(
            self,
            layer: nn.Module,
            learnable_gain: bool = True,
            const_eval: bool = False,
            eps: float = 1e-4
    ):
        super().__init__()
        self.layer = layer
        self.const_eval = const_eval

        weight_shape = self.layer.weight.shape
        # Every axis except the leading one takes part in the statistics.
        self.vmdims: Tuple[int, ...] = tuple(range(self.layer.weight.ndim))[1:]

        if not learnable_gain:
            self.register_buffer('gain', torch.ones(weight_shape[0]))
        else:
            self.gain = nn.Parameter(
                torch.ones(weight_shape[0], requires_grad=True)
            )

        # NOTE(review): this is the product of *all* weight dimensions,
        # including the leading one - confirm that this is intended.
        self.register_buffer('fan_in', torch.prod(torch.tensor(weight_shape)))
        self.register_buffer('eps', torch.tensor(eps))

    def standardize_weight(self):
        """Overwrite ``layer.weight`` with its standardized version (in place).

        The scale and shift are computed from the current weight and gain;
        the update itself runs under ``torch.no_grad()``. Nothing is returned.
        """
        weight = self.layer.weight
        variance, mu = torch.var_mean(weight, dim=self.vmdims, keepdim=True)
        bounded = torch.max(variance * self.fan_in, self.eps)
        factor = torch.rsqrt(bounded) * self.gain.view_as(variance)
        offset = mu * factor
        with torch.no_grad():
            weight.mul_(factor)
            weight.sub_(offset)
        # Out-of-place alternative that was left disabled in the original:
        # return self.layer.weight * scale - shift

    def forward(self, *args, **kwargs):
        """Standardize the weight (unless frozen for eval) and call the layer."""
        frozen_for_eval = self.const_eval and not self.training
        if not frozen_for_eval:
            self.standardize_weight()
        return self.layer(*args, **kwargs)
class WSConv3d(nn.Conv3d):
    """3D convolution that uses a weight-standardized kernel in every forward pass.

    The constructor takes the same arguments as :class:`torch.nn.Conv3d`.
    After the parent constructor ran, ``weight`` is re-initialised with
    ``nn.init.kaiming_normal_`` and a learnable ``gain`` vector of ones (one
    entry per output channel) is added.

    NOTE(review): ``padding_mode`` is forwarded to the parent constructor,
    but :meth:`forward` calls ``F.conv3d`` directly with ``self.padding``,
    so non-zero padding modes do not appear to be applied - confirm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias, padding_mode=padding_mode,
        )
        nn.init.kaiming_normal_(self.weight)
        n_filters = self.weight.shape[0]
        self.gain = nn.Parameter(torch.ones(n_filters, requires_grad=True))

    def standardize_weight(self, eps):
        """Return a standardized copy of ``self.weight`` (the parameter is not modified).

        Variance and mean are taken per output channel over dims 1-4.
        ``var * fan_in`` is bounded from below by ``eps`` before the
        reciprocal square root, and the result is multiplied by ``gain``.
        """
        kernel = self.weight
        variance, mu = torch.var_mean(kernel, dim=(1, 2, 3, 4), keepdim=True)
        # NOTE(review): product of *all* weight dims, including out_channels.
        fan_in = torch.prod(torch.tensor(kernel.shape))
        floor = torch.tensor(eps).to(variance.device)
        inv_std = torch.rsqrt(torch.max(variance * fan_in, floor))
        factor = inv_std * self.gain.view_as(variance).to(variance.device)
        offset = mu * factor
        return kernel * factor - offset

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the standardized kernel."""
        return F.conv3d(
            input, self.standardize_weight(eps), self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )
class WSConvTranspose3d(nn.ConvTranspose3d):
    """3D transposed convolution that uses a weight-standardized kernel.

    The constructor takes the same arguments as
    :class:`torch.nn.ConvTranspose3d`. After the parent constructor ran,
    ``weight`` is re-initialised with ``nn.init.kaiming_normal_`` and a
    learnable ``gain`` vector of ones is added, with one entry per slice
    along the first weight dimension.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size,
                 stride=1, padding=0, output_padding=0, groups: int = 1,
                 bias: bool = True, dilation: int = 1,
                 padding_mode: str = 'zeros'):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, bias=bias, dilation=dilation,
            padding_mode=padding_mode,
        )
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(
            torch.ones(self.weight.size(0), requires_grad=True))

    def standardize_weight(self, eps):
        """Return a standardized copy of ``self.weight`` (the parameter is not modified).

        Variance and mean are taken over dims 1-4 for each slice along
        dim 0. ``var * fan_in`` is bounded from below by ``eps`` before the
        reciprocal square root, and the result is multiplied by ``gain``.
        """
        var, mean = torch.var_mean(
            self.weight, dim=(1, 2, 3, 4), keepdim=True)
        # NOTE(review): product of *all* weight dims, including dim 0.
        # Kept as is so that existing checkpoints behave the same.
        fan_in = torch.prod(torch.tensor(self.weight.shape))
        floor = torch.tensor(eps).to(var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) \
            * self.gain.view_as(var).to(var.device)
        shift = mean * scale
        return self.weight * scale - shift

    def _ws_output_padding(
            self, input: Tensor, output_size: Optional[List[int]]
    ) -> Tuple[int, ...]:
        """Translate a requested ``output_size`` into per-dimension output padding.

        Returns ``self.output_padding`` when ``output_size`` is ``None``.
        ``output_size`` may hold only the spatial sizes or the full size
        including batch and channel entries (which are dropped).

        Raises:
            ValueError: If ``output_size`` has the wrong length or asks for
                a size this layer cannot produce for ``input``.
        """
        if output_size is None:
            return self.output_padding

        n_spatial = len(self.kernel_size)
        requested = list(output_size)
        if len(requested) == n_spatial + 2:
            requested = requested[-n_spatial:]
        if len(requested) != n_spatial:
            raise ValueError(
                f'output_size must have {n_spatial} or {n_spatial + 2} '
                f'elements (got {len(output_size)})')

        extra = []
        for d in range(n_spatial):
            # Smallest output size, i.e. the one with zero output padding.
            smallest = ((input.size(d - n_spatial) - 1) * self.stride[d]
                        - 2 * self.padding[d]
                        + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            # Output padding must stay below max(stride, dilation).
            largest = smallest + max(self.stride[d], self.dilation[d]) - 1
            if not smallest <= requested[d] <= largest:
                raise ValueError(
                    f'requested output size {tuple(requested)} is out of '
                    f'range in spatial dimension {d}: expected a value '
                    f'between {smallest} and {largest}')
            extra.append(requested[d] - smallest)
        return tuple(extra)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None,
                eps: float = 1e-4) -> Tensor:
        """Apply the transposed 3D convolution with the standardized kernel.

        Args:
            input: Input tensor.
            output_size: Optional desired output size, used to pick the
                output padding when several output sizes are possible.
            eps: Lower bound used in :meth:`standardize_weight`.
        """
        weight = self.standardize_weight(eps)
        output_padding = self._ws_output_padding(input, output_size)
        # The kernel is 5-D, so the 3D functional has to be used here.
        return F.conv_transpose3d(
            input, weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)
class WSConv1d(nn.Conv1d):
    r"""1D convolution that uses a weight-standardized kernel in every forward pass.

    Works like :class:`torch.nn.Conv1d`, except that the convolution is
    computed with a standardized version of ``weight`` (see
    :meth:`standardize_weight`) instead of the raw parameter.

    For an input of size :math:`(N, C_{\text{in}}, L)` the output
    :math:`(N, C_{\text{out}}, L_{\text{out}})` is

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
        \star \text{input}(N_i, k)

    where :math:`\star` is the valid `cross-correlation`_ operator,
    :math:`N` is the batch size, :math:`C` the number of channels and
    :math:`L` the length of the signal sequence.

    * :attr:`stride` controls the stride of the cross-correlation.
    * :attr:`padding` controls the amount of implicit padding on both sides.
    * :attr:`dilation` controls the spacing between the kernel points
      (à trous algorithm); see this `link`_ for a visualization.
    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible
      by :attr:`groups`. With ``groups == in_channels`` and
      ``out_channels == K * in_channels`` the operation is a depthwise
      convolution with multiplier ``K``.

    Note:
        No padding is added automatically beyond :attr:`padding`, so some
        of the last input columns may not contribute to the output.

    NOTE(review): ``padding_mode`` is forwarded to the parent constructor,
    but :meth:`forward` calls ``F.conv1d`` directly with ``self.padding``,
    so non-zero padding modes do not appear to be applied - confirm.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (string, optional): Passed to :class:`torch.nn.Conv1d`.
            Default: ``'zeros'``

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` where

          .. math::
              L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding}
                        - \text{dilation} \times (\text{kernel\_size} - 1) - 1}
                        {\text{stride}} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable kernel of shape
            :math:`(\text{out\_channels},
            \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`,
            re-initialised with ``nn.init.kaiming_normal_`` in the constructor.
        bias (Tensor): the learnable bias of shape (out_channels), created
            by the parent class if :attr:`bias` is ``True``.
        gain (Tensor): learnable per-output-channel gain, initialised to ones.

    Examples::

        >>> m = WSConv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias, padding_mode=padding_mode,
        )
        nn.init.kaiming_normal_(self.weight)
        n_filters = self.weight.size()[0]
        self.gain = nn.Parameter(torch.ones(n_filters, requires_grad=True))

    def standardize_weight(self, eps):
        """Return a standardized copy of ``self.weight`` (the parameter is not modified).

        Variance and mean are taken per output channel over dims 1-2.
        ``var * fan_in`` is bounded from below by ``eps`` before the
        reciprocal square root, and the result is multiplied by ``gain``.
        """
        kernel = self.weight
        variance, mu = torch.var_mean(kernel, dim=(1, 2), keepdim=True)
        # NOTE(review): product of *all* weight dims, including out_channels.
        fan_in = torch.prod(torch.tensor(kernel.shape))
        floor = torch.tensor(eps).to(variance.device)
        inv_std = torch.rsqrt(torch.max(variance * fan_in, floor))
        factor = inv_std * self.gain.view_as(variance).to(variance.device)
        offset = mu * factor
        return kernel * factor - offset

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the standardized kernel."""
        return F.conv1d(
            input, self.standardize_weight(eps), self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )
class WSConv2d(nn.Conv2d):
    r"""Applies a 2D convolution over an input signal composed of several input
    planes after weight normalization/standardization.

    Reference: https://github.com/deepmind/deepmind-research/blob/master/nfnets/base.py#L121

    Works like :class:`torch.nn.Conv2d`, except that the convolution is
    computed with a standardized version of ``weight`` (see
    :meth:`standardize_weight`) instead of the raw parameter.

    For an input of size :math:`(N, C_{\text{in}}, H, W)` the output
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` is

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k)
        \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is the batch size, :math:`C` the number of channels,
    :math:`H` the height of the input planes in pixels and :math:`W` the
    width in pixels.

    * :attr:`stride` controls the stride of the cross-correlation, a single
      number or a tuple.
    * :attr:`padding` controls the amount of implicit padding on both sides
      of each dimension.
    * :attr:`dilation` controls the spacing between the kernel points
      (à trous algorithm); see this `link`_ for a visualization.
    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible
      by :attr:`groups`. With ``groups == in_channels`` and
      ``out_channels == K * in_channels`` the operation is a depthwise
      convolution with multiplier ``K``.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`,
    :attr:`dilation` can either be:

        - a single ``int`` -- the same value is used for the height and
          width dimension
        - a ``tuple`` of two ints -- the first is used for the height
          dimension, the second for the width dimension

    Note:
        No padding is added automatically beyond :attr:`padding`, so some
        of the last input columns may not contribute to the output.

    NOTE(review): ``padding_mode`` is forwarded to the parent constructor,
    but :meth:`forward` calls ``F.conv2d`` directly with ``self.padding``,
    so non-zero padding modes do not appear to be applied - confirm.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (string, optional): Passed to :class:`torch.nn.Conv2d`.
            Default: ``'zeros'``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0]
                        - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1)
                        - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1]
                        - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1)
                        - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable kernel of shape
            :math:`(\text{out\_channels},
            \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`,
            re-initialised with ``nn.init.kaiming_normal_`` in the constructor.
        bias (Tensor): the learnable bias of shape (out_channels), created
            by the parent class if :attr:`bias` is ``True``.
        gain (Tensor): learnable per-output-channel gain, initialised to ones.

    Examples:

        >>> # With square kernels and equal stride
        >>> m = WSConv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = WSConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = WSConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias, padding_mode=padding_mode,
        )
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(
            torch.ones(self.weight.size(0), requires_grad=True))

    def standardize_weight(self, eps):
        """Return a standardized copy of ``self.weight`` (the parameter is not modified).

        Variance and mean are taken per output channel over dims 1-3.
        ``var * fan_in`` is bounded from below by ``eps`` before the
        reciprocal square root, and the result is multiplied by ``gain``.
        """
        var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdim=True)
        # NOTE(review): the original sliced the shape with ``[0:]``, which
        # selects every dimension, so ``fan_in`` includes out_channels.
        # The value is kept so that existing checkpoints behave the same.
        fan_in = torch.prod(torch.tensor(self.weight.shape))
        floor = torch.tensor(eps).to(var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) \
            * self.gain.view_as(var).to(var.device)
        shift = mean * scale
        return self.weight * scale - shift

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the standardized kernel."""
        weight = self.standardize_weight(eps)
        return F.conv2d(input, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
class WSConvTranspose2d(nn.ConvTranspose2d):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes after weight normalization/standardization.

    Works like :class:`torch.nn.ConvTranspose2d`, except that the operation
    is computed with a standardized version of ``weight`` (see
    :meth:`standardize_weight`) instead of the raw parameter. A transposed
    convolution is also known as a fractionally-strided convolution or a
    deconvolution (although it is not an actual deconvolution operation).

    * :attr:`stride` controls the stride for the cross-correlation.
    * :attr:`padding` controls the amount of implicit zero-padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points.
    * :attr:`output_padding` controls the additional size added to one side
      of the output shape.
    * :attr:`dilation` controls the spacing between the kernel points
      (à trous algorithm); see this `link`_ for a visualization.
    * :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible
      by :attr:`groups`.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`,
    :attr:`output_padding` can either be:

        - a single ``int`` -- the same value is used for the height and
          width dimensions
        - a ``tuple`` of two ints -- the first is used for the height
          dimension, the second for the width dimension

    Note:
        When ``stride > 1``, a strided convolution maps several input
        shapes to the same output shape, so the output shape of the
        transposed operation is ambiguous. :attr:`output_padding` (or the
        ``output_size`` argument of :meth:`forward`) resolves this by
        increasing the calculated output shape on one side. It is only
        used to find the output shape and does not add zero-padding to
        the output.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding``
            zero-padding will be added to both sides of each dimension in
            the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to
            one side of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        padding_mode (string, optional): Passed to
            :class:`torch.nn.ConvTranspose2d`. Default: ``'zeros'``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0]
                        + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1)
                        + \text{output\_padding}[0] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1]
                        + \text{dilation}[1] \times (\text{kernel\_size}[1] - 1)
                        + \text{output\_padding}[1] + 1

    Attributes:
        weight (Tensor): the learnable kernel of shape
            :math:`(\text{in\_channels},
            \frac{\text{out\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`,
            re-initialised with ``nn.init.kaiming_normal_`` in the constructor.
        bias (Tensor): the learnable bias of shape (out_channels), created
            by the parent class if :attr:`bias` is ``True``.
        gain (Tensor): learnable gain, initialised to ones, with one entry
            per slice along the first weight dimension.

    Examples::

        >>> # With square kernels and equal stride
        >>> m = WSConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = WSConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> downsample = WSConv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = WSConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size,
                 stride=1, padding=0, output_padding=0, groups: int = 1,
                 bias: bool = True, dilation: int = 1,
                 padding_mode: str = 'zeros'):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, bias=bias, dilation=dilation,
            padding_mode=padding_mode,
        )
        nn.init.kaiming_normal_(self.weight)
        self.gain = nn.Parameter(
            torch.ones(self.weight.size(0), requires_grad=True))

    def standardize_weight(self, eps):
        """Return a standardized copy of ``self.weight`` (the parameter is not modified).

        Variance and mean are taken over dims 1-3 for each slice along
        dim 0. ``var * fan_in`` is bounded from below by ``eps`` before the
        reciprocal square root, and the result is multiplied by ``gain``.
        """
        var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdim=True)
        # NOTE(review): the original sliced the shape with ``[0:]``, which
        # selects every dimension, so ``fan_in`` includes dim 0 as well.
        # The value is kept so that existing checkpoints behave the same.
        fan_in = torch.prod(torch.tensor(self.weight.shape))
        floor = torch.tensor(eps).to(var.device)
        scale = torch.rsqrt(torch.max(var * fan_in, floor)) \
            * self.gain.view_as(var).to(var.device)
        shift = mean * scale
        return self.weight * scale - shift

    def _ws_output_padding(
            self, input: Tensor, output_size: Optional[List[int]]
    ) -> Tuple[int, ...]:
        """Translate a requested ``output_size`` into per-dimension output padding.

        Returns ``self.output_padding`` when ``output_size`` is ``None``.
        ``output_size`` may hold only the spatial sizes or the full size
        including batch and channel entries (which are dropped).

        Raises:
            ValueError: If ``output_size`` has the wrong length or asks for
                a size this layer cannot produce for ``input``.
        """
        if output_size is None:
            return self.output_padding

        n_spatial = len(self.kernel_size)
        requested = list(output_size)
        if len(requested) == n_spatial + 2:
            requested = requested[-n_spatial:]
        if len(requested) != n_spatial:
            raise ValueError(
                f'output_size must have {n_spatial} or {n_spatial + 2} '
                f'elements (got {len(output_size)})')

        extra = []
        for d in range(n_spatial):
            # Smallest output size, i.e. the one with zero output padding.
            smallest = ((input.size(d - n_spatial) - 1) * self.stride[d]
                        - 2 * self.padding[d]
                        + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            # Output padding must stay below max(stride, dilation).
            largest = smallest + max(self.stride[d], self.dilation[d]) - 1
            if not smallest <= requested[d] <= largest:
                raise ValueError(
                    f'requested output size {tuple(requested)} is out of '
                    f'range in spatial dimension {d}: expected a value '
                    f'between {smallest} and {largest}')
            extra.append(requested[d] - smallest)
        return tuple(extra)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None,
                eps: float = 1e-4) -> Tensor:
        """Apply the transposed 2D convolution with the standardized kernel.

        Args:
            input: Input tensor.
            output_size: Optional desired output size, used to pick the
                output padding when several output sizes are possible.
            eps: Lower bound used in :meth:`standardize_weight`.
        """
        weight = self.standardize_weight(eps)
        output_padding = self._ws_output_padding(input, output_size)
        # Use the standardized kernel here, not the raw ``self.weight``.
        return F.conv_transpose2d(
            input, weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)