Source code for elektronn3.modules.l1batchnorm

# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Martin Drawitsch

import torch
from torch import nn


# TODO: ScriptModule
# @torch._jit_internal.weak_module
# class L1BatchNorm(torch.jit.ScriptModule):
[docs] class L1BatchNorm(nn.Module): """L1-Norm-based Batch Normalization module. Use with caution. This code is not extensively tested. References: - https://arxiv.org/abs/1802.09769 - https://arxiv.org/abs/1803.01814 """ __constants__ = ['l2factor', 'eps', 'momentum'] def __init__(self, num_features: int, momentum: float = 0.9): super().__init__() self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.zeros(num_features)) self.momentum = momentum self.gamma = nn.Parameter(torch.ones(1, num_features)) self.beta = nn.Parameter(torch.zeros(1, num_features)) self.eps = 1e-5 self.l2factor = (3.1416 / 2) ** 0.5 # @torch._jit_internal.weak_script # @torch.jit.script_method
[docs] def forward(self, x): ndim = x.dim() # If this is known statically, this module can be a ScriptModule reduce_dims = (0, 2, 3, 4)[:ndim] b_sh = (1, x.shape[1], 1, 1, 1)[:ndim] # Broadcastable shape if self.training: mean = x.mean(dim=reduce_dims, keepdim=True) meandiff = x - mean absdiff = meandiff.abs() l1mean = absdiff.mean(dim=reduce_dims, keepdim=True) l1scaled = l1mean * self.l2factor + self.eps with torch.no_grad(): # Update running stats mom = self.momentum self.running_mean.mul_(mom).add_(mean.flatten() * (1 - mom)) self.running_var.mul_(mom).add_(l1scaled.flatten() * (1 - mom)) else: mean = self.running_mean.view(b_sh) l1scaled = self.running_var.view(b_sh) meandiff = x - mean gamma = self.gamma.view(b_sh) beta = self.beta.view(b_sh) return gamma * meandiff / l1scaled + beta
# @torch._jit_internal.weak_script
[docs] def l1_group_norm(x, num_groups, weight, bias, eps): l2factor = 1.2533 # == (pi / 2) ** 0.5 ndim = x.dim() sh = x.shape g = num_groups n, c = sh[:2] # grouped_sh = (n, g, c // g, d, h, w) grouped_sh = (n, g, c // g, *sh[2:]) # Split C dimension into groups grouped = x.view(grouped_sh) reduce_dims = (2, 3, 4, 5)[:ndim - 1] # Reduce over grouped channels and spatial dimensions mean = grouped.mean(dim=reduce_dims, keepdim=True) meandiff = grouped - mean absdiff = meandiff.abs() l1mean = absdiff.mean(dim=reduce_dims, keepdim=True) l1scaled = l1mean * l2factor + eps normalized = meandiff / l1scaled normalized = normalized.view(sh) broadcast_sh = (1, c, 1, 1, 1)[:ndim] # Shape broadcastable over all dims of x weight = weight.view(broadcast_sh) bias = bias.view(broadcast_sh) return weight * normalized + bias
# @torch._jit_internal.weak_module
[docs] class L1GroupNorm(nn.GroupNorm): r"""Applies L1 Group Normalization over a mini-batch of inputs. This works in the same way as `torch.nn.GroupNorm`, but uses the scaled L1 norm instead of the L2 norm for better numerical stability, performance and half precision support. L1 *batch* normalization was proposed in - https://arxiv.org/abs/1802.09769 - https://arxiv.org/abs/1803.01814 This layer uses statistics computed from input data in both training and evaluation modes. Args: num_groups (int): number of groups to separate the channels into num_channels (int): number of channels expected in input eps: a value added to the denominator for numerical stability. Default: 1e-5 affine: a boolean value that when set to ``True``, this module has learnable per-channel affine parameters initialized to ones (for weights) and zeros (for biases). Default: ``True``. Shape: - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}` - Output: :math:`(N, C, *)` (same shape as input) .. _`Group Normalization`: https://arxiv.org/abs/1803.08494 """ __constants__ = ['num_groups', 'num_channels', 'eps', 'affine', 'weight', 'bias'] def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): super().__init__(num_groups, num_channels, eps, affine) print('Warning: L1 Group norm is experimental and may have issues.') @torch._jit_internal.weak_script_method def forward(self, input): return l1_group_norm(input, self.num_groups, self.weight, self.bias, self.eps)