elektronn3.training.swa module¶
- class elektronn3.training.swa.SWA(*args: Any, **kwargs: Any)[source]¶
Bases: Optimizer
- add_param_group(param_group)[source]¶
Add a param group to the Optimizer's param_groups. This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.
- Parameters:
param_group (dict) – Specifies which Tensors should be optimized, along with group-specific optimization options.
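As a sketch of the unfreezing pattern described above, using a plain torch.optim.SGD optimizer (the `backbone`/`head` modules are hypothetical stand-ins):

```python
import torch

# Hypothetical fine-tuning setup: only `head` is trainable at first.
backbone = torch.nn.Linear(4, 4)
head = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(head.parameters(), lr=1e-2)

# Later, unfreeze the backbone and register it as a new param group
# with its own (smaller) learning rate:
opt.add_param_group({'params': backbone.parameters(), 'lr': 1e-3})
```

After the call, `opt.param_groups` contains two groups, each with its own hyperparameters.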
- static bn_update(loader, model, device=None)[source]¶
Updates BatchNorm running_mean, running_var buffers in the model.
It performs one pass over data in loader to estimate the activation statistics for BatchNorm layers in the model.
- Parameters:
loader (torch.utils.data.DataLoader) – dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data.
model (torch.nn.Module) – model for which we seek to update BatchNorm statistics.
device (torch.device, optional) – If set, data will be transferred to device before being passed into model.
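A minimal pure-PyTorch sketch of the procedure bn_update performs, assuming each batch is a plain tensor (a small list stands in for the DataLoader here):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.BatchNorm1d(8))
loader = [torch.randn(16, 3) for _ in range(4)]  # stand-in for a DataLoader

was_training = model.training
model.train()  # BN layers only update running stats in training mode
momenta = {}
for m in model.modules():
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        m.reset_running_stats()
        momenta[m] = m.momentum
        m.momentum = None  # None -> cumulative moving average over batches
with torch.no_grad():
    for batch in loader:
        model(batch)  # one pass over the data to estimate statistics
for m, mom in momenta.items():
    m.momentum = mom  # restore the original momenta
model.train(was_training)
```

Setting `momentum=None` makes BatchNorm accumulate an exact average over all batches of the pass, rather than an exponential moving average.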
- load_state_dict(state_dict)[source]¶
Loads the optimizer state.
- Parameters:
state_dict (dict) – SWA optimizer state. Should be an object returned from a call to state_dict.
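A round-trip sketch with a plain torch.optim.SGD base optimizer; the same state_dict()/load_state_dict() pattern applies to the SWA wrapper:

```python
import torch

x = torch.nn.Parameter(torch.zeros(2))
opt = torch.optim.SGD([x], lr=0.1, momentum=0.9)
x.grad = torch.ones(2)
opt.step()  # populates the momentum buffers in the optimizer state

snapshot = opt.state_dict()  # a plain dict, suitable for torch.save()

opt2 = torch.optim.SGD([x], lr=0.1, momentum=0.9)
opt2.load_state_dict(snapshot)  # restores hyperparameters and buffers
```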
- state_dict()[source]¶
Returns the state of SWA as a dict.
It contains three entries:
- opt_state – a dict holding the current optimization state of the base optimizer. Its content differs between optimizer classes.
- swa_state – a dict containing the current state of SWA. For each optimized variable it contains a swa_buffer keeping the running average of the variable.
- param_groups – a dict containing all parameter groups.
- step(closure=None)[source]¶
Performs a single optimization step.
In automatic mode, this also updates the SWA running averages.
- swap_swa_sgd()[source]¶
Swaps the values of the optimized variables and swa buffers.
It is meant to be called at the end of training to use the collected SWA running averages. It can also be used to evaluate the running averages during training; to continue training afterwards, swap_swa_sgd should be called again.
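The swap itself is just an exchange of parameter data with the SWA buffers; a toy sketch (with `swa_buffer` as a stand-in for the internal running-average buffer):

```python
import torch

param = torch.nn.Parameter(torch.tensor([2.0, 4.0]))  # last SGD iterate
swa_buffer = torch.tensor([1.0, 1.0])                 # collected average

# Exchange the parameter data with the SWA buffer.
tmp = swa_buffer.clone()
swa_buffer.copy_(param.data)
param.data.copy_(tmp)
# `param` now holds the averaged weights; running the same swap again
# would restore the SGD iterate, which is how training can continue.
```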
- update_swa_group(group)[source]¶
Updates the SWA running averages for the given parameter group.
- Parameters:
group (dict) – Specifies the parameter group for which the SWA running averages should be updated.
Examples
>>> # automatic mode
>>> base_opt = torch.optim.SGD([{'params': [x]},
>>>                             {'params': [y], 'lr': 1e-3}],
>>>                            lr=1e-2, momentum=0.9)
>>> opt = torchcontrib.optim.SWA(base_opt)
>>> for i in range(100):
>>>     opt.zero_grad()
>>>     loss_fn(model(input), target).backward()
>>>     opt.step()
>>>     if i > 10 and i % 5 == 0:
>>>         # Update SWA for the second parameter group
>>>         opt.update_swa_group(opt.param_groups[1])
>>> opt.swap_swa_sgd()
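The running average that update_swa_group maintains follows swa ← (swa · n + p) / (n + 1), where n counts how many times the average has been updated. A toy sketch of that update rule (the iterate values are made up):

```python
import torch

# Parameter iterates as they might appear at successive SWA updates.
iterates = [torch.tensor([4.0]), torch.tensor([2.0]), torch.tensor([0.0])]
swa = torch.zeros(1)
n_avg = 0
for p in iterates:
    # swa <- (swa * n_avg + p) / (n_avg + 1)
    swa.mul_(n_avg / (n_avg + 1)).add_(p / (n_avg + 1))
    n_avg += 1
# swa now equals the mean of the three iterates: tensor([2.])
```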