elektronn3.training.recalibration module

Normalization layer recalibration tools

Based on https://github.com/mdraw/contrib/blob/bdf4da5/torchcontrib/optim/swa.py

exception elektronn3.training.recalibration.NoApplicableLayersException[source]

Bases: Exception

elektronn3.training.recalibration.recalibrate_bn(loader, model, device=None)[source]

Returns the model with the running_mean and running_var buffers of its normalization layers recalibrated.

It performs one full pass over the data in loader to estimate the activation statistics of the BatchNorm layers in the model.

Parameters:
  • loader (torch.utils.data.DataLoader) – dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data.

  • model (torch.nn.Module) – model for which we seek to update BatchNorm statistics.

  • device (torch.device, optional) – If set, data will be transferred to device before being passed into model.
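
The linked torchcontrib SWA code suggests the general recipe: reset each BatchNorm layer's running statistics, switch it to cumulative averaging (momentum=None), and run one forward pass over the loader in training mode. The sketch below illustrates that recipe; the function name recalibrate_bn_sketch and the RuntimeError are illustrative, and the actual elektronn3 implementation may differ in details:

```python
import torch
import torch.nn as nn


def recalibrate_bn_sketch(loader, model, device=None):
    """Illustrative BatchNorm recalibration pass (SWA-style bn_update)."""
    # Collect all BatchNorm layers (1d/2d/3d share the _BatchNorm base class).
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    if not bn_layers:
        # elektronn3 raises NoApplicableLayersException here; we use a
        # plain RuntimeError to keep the sketch self-contained.
        raise RuntimeError('Model contains no BatchNorm layers')

    momenta = {}
    for bn in bn_layers:
        bn.reset_running_stats()     # zero running_mean/var, reset counter
        momenta[bn] = bn.momentum
        bn.momentum = None           # None => cumulative moving average

    was_training = model.training
    model.train()                    # stats only update in training mode
    with torch.no_grad():
        for batch in loader:
            # Each batch is a tensor, or a list/tuple whose first
            # element is the data tensor.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            if device is not None:
                x = x.to(device)
            model(x)

    # Restore the original momenta and training/eval state.
    for bn, momentum in momenta.items():
        bn.momentum = momentum
    model.train(was_training)
    return model
```

With momentum=None and equally sized batches, the recalibrated running_mean converges to the mean of the data seen during the pass, which is the intended estimate of the deployment-time activation statistics.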