elektronn3.training.swa module

class elektronn3.training.swa.SWA(*args: Any, **kwargs: Any)

Bases: torch.optim.


Add a param group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

  • param_group (dict) – Specifies what Tensors should be optimized along with group-specific optimization options.
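A minimal plain-Python sketch of how group-specific options interact with the optimizer's defaults. It assumes the same behavior as torch.optim.Optimizer.add_param_group, where any option missing from the new group is filled in from the optimizer's defaults. The helper `add_param_group` below is hypothetical and works on plain dicts; strings stand in for parameter tensors, and the real method's duplicate-parameter checks are omitted.

```python
from typing import Any, Dict, List


def add_param_group(param_groups: List[Dict[str, Any]],
                    defaults: Dict[str, Any],
                    param_group: Dict[str, Any]) -> None:
    """Hypothetical helper: append param_group, filling omitted options from defaults."""
    group = dict(param_group)                # copy so the caller's dict is untouched
    group["params"] = list(group["params"])  # accept any iterable of parameters
    for name, default in defaults.items():
        group.setdefault(name, default)      # group-specific values take precedence
    param_groups.append(group)


defaults = {"lr": 1e-2, "momentum": 0.9}
param_groups = [{"params": ["encoder.weight"], "lr": 1e-2, "momentum": 0.9}]

# Later in training: unfreeze a layer and train it with a smaller learning rate.
add_param_group(param_groups, defaults, {"params": ["decoder.weight"], "lr": 1e-3})
print(param_groups[1])  # {'params': ['decoder.weight'], 'lr': 0.001, 'momentum': 0.9}
```

With a real optimizer, the "params" entry would hold tensors, for example the parameters() of the layer being unfrozen.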

static bn_update(loader, model, device=None)

Updates the BatchNorm running_mean and running_var buffers in the model.

It performs one pass over the data in loader to estimate the activation statistics for the BatchNorm layers in the model.

  • loader (torch.utils.data.DataLoader) – dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data.

  • model (torch.nn.Module) – model for which we seek to update BatchNorm statistics.

  • device (torch.device, optional) – If set, data will be transferred to device before being passed into model.
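The single pass can be sketched for one BatchNorm channel in plain Python. This assumes each batch's statistics are folded in with momentum b / (n + b), where b is the batch size and n the number of samples already seen, which turns the usual exponential update into a sample-weighted cumulative average. `one_pass_bn_stats` is a hypothetical helper for illustration, not part of this module, and lists of floats stand in for activation tensors.

```python
import statistics
from typing import Iterable, List, Tuple


def one_pass_bn_stats(batches: Iterable[List[float]]) -> Tuple[float, float]:
    """Hypothetical sketch: estimate one channel's running mean/var in a single pass.

    Each batch contributes its mean and unbiased variance, weighted by its size.
    """
    running_mean, running_var = 0.0, 1.0  # values after resetting BatchNorm statistics
    n = 0                                 # number of samples seen so far
    for batch in batches:
        b = len(batch)
        momentum = b / (n + b)            # equals 1 for the first batch, then shrinks
        running_mean = (1 - momentum) * running_mean + momentum * statistics.mean(batch)
        running_var = (1 - momentum) * running_var + momentum * statistics.variance(batch)
        n += b
    return running_mean, running_var


mean, var = one_pass_bn_stats([[1.0, 2.0, 3.0], [4.0, 5.0]])
print(mean, var)
```

With this weighting the running mean equals the mean over all samples, while the running variance is a size-weighted average of per-batch variances rather than the exact variance of the whole dataset.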


Loads the optimizer state.


state_dict (dict) – SWA optimizer state. Should be an object returned from a call to state_dict.


Returns the state of SWA as a dict.

It contains three entries:
  • opt_state - a dict holding the current optimization state of the base optimizer. Its content differs between optimizer classes.

  • swa_state - a dict containing the current state of SWA. For each optimized variable it contains swa_buffer, which keeps the running average of that variable.

  • param_groups - a list containing all parameter groups.
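To illustrate the layout, here is a hypothetical sketch that assembles the three entries from a base optimizer state dict (in torch.optim, Optimizer.state_dict() returns a dict with the keys "state" and "param_groups") plus the SWA buffers, and splits them back apart as loading would need to. The helper names are made up, and how this module actually performs the split is an assumption.

```python
from typing import Any, Dict, Tuple


def swa_state_dict(base_state: Dict[str, Any], swa_state: Dict[Any, Any]) -> Dict[str, Any]:
    """Hypothetical: combine a base optimizer state dict with SWA buffers."""
    return {
        "opt_state": base_state["state"],            # e.g. momentum buffers
        "swa_state": swa_state,                      # per-parameter swa_buffer
        "param_groups": base_state["param_groups"],  # shared by both halves
    }


def split_swa_state_dict(state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[Any, Any]]:
    """Hypothetical inverse: recover the base optimizer's {"state", "param_groups"} layout."""
    base_state = {"state": state["opt_state"], "param_groups": state["param_groups"]}
    return base_state, state["swa_state"]


base = {"state": {0: {"momentum_buffer": [0.1]}},
        "param_groups": [{"lr": 0.01, "params": [0]}]}
swa = {0: {"swa_buffer": [0.5]}}

ckpt = swa_state_dict(base, swa)
restored_base, restored_swa = split_swa_state_dict(ckpt)
print(sorted(ckpt))  # ['opt_state', 'param_groups', 'swa_state']
```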


Performs a single optimization step.

In automatic mode, it also updates the SWA running averages.
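A plain-Python sketch of what an automatic averaging schedule could look like. The names swa_start and swa_freq and the exact rule (update the averages after every swa_freq-th step once more than swa_start steps have run) are assumptions for illustration; the constructor arguments of this class are not documented on this page.

```python
from typing import List


def averaging_steps(num_steps: int, swa_start: int, swa_freq: int) -> List[int]:
    """Hypothetical: 1-based step counts after which SWA averages would be updated."""
    updates = []
    for step in range(1, num_steps + 1):
        # Assumed rule: skip the first swa_start steps, then average every swa_freq steps.
        if step > swa_start and step % swa_freq == 0:
            updates.append(step)
    return updates


print(averaging_steps(30, swa_start=10, swa_freq=5))  # [15, 20, 25, 30]
```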


Swaps the values of the optimized variables and SWA buffers.

It is meant to be called at the end of training to use the collected SWA running averages. It can also be used to evaluate the running averages during training; to continue training, swap_swa_sgd should be called again.
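The swap semantics can be illustrated with plain lists. `swap_in_place` is a hypothetical helper standing in for the per-variable swap; calling it twice restores the original values, which is why a second call lets training resume after an evaluation with the averaged weights.

```python
from typing import List


def swap_in_place(params: List[float], swa_buffers: List[float]) -> None:
    """Hypothetical: exchange current weights with their SWA running averages."""
    for i in range(len(params)):
        params[i], swa_buffers[i] = swa_buffers[i], params[i]


weights = [0.9, -0.2]  # current SGD iterate
averages = [0.5, 0.1]  # SWA running averages

swap_in_place(weights, averages)  # evaluate with the averaged weights
print(weights)                    # [0.5, 0.1]
swap_in_place(weights, averages)  # swap back to continue training
print(weights)                    # [0.9, -0.2]
```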


Updates the SWA running averages of all optimized parameters.
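Assuming the running average is an equal-weight mean over the snapshots collected so far, each update can be written as buffer = buffer + (p - buffer) / (n_avg + 1), where n_avg counts the snapshots already averaged. A minimal plain-Python sketch with a hypothetical helper, using lists of floats in place of tensors:

```python
from typing import List


def update_running_average(swa_buffer: List[float], params: List[float], n_avg: int) -> int:
    """Hypothetical: fold params into swa_buffer in place and return the new count."""
    decay = 1.0 / (n_avg + 1)  # 1 on the first update, so the buffer becomes a copy
    for i, p in enumerate(params):
        swa_buffer[i] += (p - swa_buffer[i]) * decay
    return n_avg + 1


buffer, n_avg = [0.0, 0.0], 0
for snapshot in ([1.0, 2.0], [3.0, 4.0], [5.0, 9.0]):
    n_avg = update_running_average(buffer, snapshot, n_avg)
print(buffer)  # element-wise mean of the three snapshots
```

The incremental form avoids storing every snapshot: only the buffer and the count are kept.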


Updates the SWA running averages for the given parameter group.


param_group (dict) – Specifies the parameter group for which SWA running averages should be updated.


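>>> # manual mode: SWA averages are updated explicitly via update_swa_group
>>> import torch
>>> from elektronn3.training.swa import SWA
>>> # Toy model and data for illustration
>>> model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 1))
>>> inputs, target = torch.randn(16, 4), torch.randn(16, 1)
>>> loss_fn = torch.nn.MSELoss()
>>> base_opt = torch.optim.SGD([{'params': model[0].parameters()},
...                             {'params': model[1].parameters(), 'lr': 1e-3}],
...                            lr=1e-2, momentum=0.9)
>>> opt = SWA(base_opt)
>>> for i in range(100):
...     opt.zero_grad()
...     loss_fn(model(inputs), target).backward()
...     opt.step()
...     if i > 10 and i % 5 == 0:
...         # Update SWA for the second parameter group
...         opt.update_swa_group(opt.param_groups[1])
>>> # Copy the averaged values into the model's parameters
>>> opt.swap_swa_sgd()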