elektronn3.training.recalibration module
Normalization layer recalibration tools
Based on https://github.com/mdraw/contrib/blob/bdf4da5/torchcontrib/optim/swa.py
- elektronn3.training.recalibration.recalibrate_bn(loader, model, device=None)
Returns the model with the running_mean and running_var buffers of its normalization layers recalibrated.
It performs one pass over the data in loader to estimate the activation statistics for the BatchNorm layers in model.
- Parameters:
loader (torch.utils.data.DataLoader) – dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data.
model (torch.nn.Module) – model for which we seek to update BatchNorm statistics.
device (torch.device, optional) – If set, data will be transferred to device before being passed into model.
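The following is a minimal sketch of how such a one-pass recalibration can work, assuming the approach of the torchcontrib SWA code linked above: running statistics are reset, and each layer's momentum is set to b / (n + b) (b = current batch size, n = samples seen so far), which turns BatchNorm's exponential moving average into a sample-weighted average over the whole loader. The name `recalibrate_bn_sketch` is hypothetical; this is not elektronn3's actual implementation, and it only handles `nn.BatchNorm1d/2d/3d`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Assumption: only the standard BatchNorm layers are recalibrated here.
_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def recalibrate_bn_sketch(loader, model, device=None):
    """Hypothetical illustration of one-pass BatchNorm recalibration."""
    bn_layers = [m for m in model.modules() if isinstance(m, _BN_TYPES)]
    if not bn_layers:
        return model
    was_training = model.training
    # Remember each layer's configured momentum so it can be restored afterwards.
    old_momenta = {m: m.momentum for m in bn_layers}
    for m in bn_layers:
        m.reset_running_stats()  # running_mean <- 0, running_var <- 1
    model.train()  # BatchNorm only updates running stats in training mode
    n_seen = 0
    with torch.no_grad():
        for batch in loader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]  # first element holds the input data
            b = batch.size(0)
            # momentum = b / (n_seen + b) makes the running estimate a
            # sample-weighted average over all batches seen so far.
            for m in bn_layers:
                m.momentum = b / float(n_seen + b)
            if device is not None:
                batch = batch.to(device)
            model(batch)
            n_seen += b
    # Restore the original momenta and the original train/eval mode.
    for m, momentum in old_momenta.items():
        m.momentum = momentum
    model.train(was_training)
    return model
```

With this weighting, running_mean ends up equal to the mean over the entire dataset, while running_var becomes a batch-size-weighted average of the per-batch (unbiased) variances, which approximates but is not identical to the dataset variance. A typical use is after weight averaging (as in SWA), where the averaged weights no longer match the stored statistics.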