Backup class for archiving the training script, the src folder and environment info. Should be used for any future archiving needs.

script_path – The path to the training script, e.g. train_unet_neurodata.py.
save_path – The path where the information is archived.

Archives the source folder, the training script and environment info. The training script is saved with the prefix “0-” to distinguish it from regular scripts. Environment information equivalent to the output of python -m torch.utils.collect_env is saved in a file named “env_info.txt”.
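The archiving behavior described above can be sketched as follows. This is a hypothetical stand-in using only the standard library, not the actual Backup implementation; in the real class, env_info.txt holds the output of python -m torch.utils.collect_env, for which a placeholder is written here:

```python
# Illustrative sketch of the described archiving steps (hypothetical helper,
# not the actual Backup class): copy the training script with a "0-" prefix,
# zip the source folder and write environment info to "env_info.txt".
import shutil
import zipfile
from pathlib import Path


def archive_backup(script_path: str, src_path: str, save_path: str) -> None:
    save = Path(save_path)
    save.mkdir(parents=True, exist_ok=True)
    script = Path(script_path)
    # Prefix the script copy with "0-" to distinguish it from regular scripts.
    shutil.copy2(script, save / f'0-{script.name}')
    # Archive the whole source folder as a zip file.
    with zipfile.ZipFile(save / 'src.zip', 'w', zipfile.ZIP_DEFLATED) as zf:
        for p in Path(src_path).rglob('*'):
            if p.is_file():
                zf.write(p, p.relative_to(src_path))
    # The real class stores `python -m torch.utils.collect_env` output here.
    (save / 'env_info.txt').write_text('collect_env output placeholder\n')
```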
Raised when a NaN value is detected.
Trainer(model, criterion, optimizer, device, save_root, train_dataset, valid_dataset=None, unlabeled_dataset=None, valid_metrics=None, ss_criterion=None, preview_batch=None, preview_interval=5, inference_kwargs=None, hparams=None, extra_save_steps=(), exp_name=None, example_input=None, enable_save_trace=False, save_jit=None, batch_size=1, num_workers=0, schedulers=None, overlay_alpha=0.4, enable_videos=False, enable_tensorboard=True, tensorboard_root_path=None, ignore_errors=False, ipython_shell=False, out_channels=None, sample_plotting_handler=None, preview_plotting_handler=None, mixed_precision=False)
General training loop abstraction for supervised training.
model (torch.nn.Module) – PyTorch model (nn.Module) that shall be trained. Please make sure that the output shape of the model matches the shape of targets that are delivered by the train_dataset.
criterion (torch.nn.Module) – PyTorch loss that shall be used as the optimization criterion.
optimizer (torch.optim.Optimizer) – PyTorch optimizer that shall be used to update model weights according to the criterion in each iteration.
device (torch.device) – The device on which the network shall be trained.
train_dataset (torch.utils.data.Dataset) – PyTorch dataset (data.Dataset) which produces training samples when iterated over. elektronn3.data.cnndata.PatchCreator is currently recommended for constructing datasets.
valid_dataset (Optional[Dataset]) – PyTorch dataset (data.Dataset) which produces validation samples when iterated over. Its length (len(valid_dataset)) determines how many samples are used for one validation metric calculation.
unlabeled_dataset (Optional[Dataset]) – Unlabeled dataset (only inputs) for semi-supervised training. If this is supplied, ss_criterion needs to be set to the loss that should be computed on unlabeled inputs.
valid_metrics (Optional[Dict]) – Validation metrics to be calculated on validation data after each training epoch. All metrics are logged to tensorboard.
ss_criterion (Optional[nn.Module]) – Loss criterion for the self-supervised part of semi-supervised training. The ss_criterion loss is computed on batches from the unlabeled_dataset and added to the supervised loss in each training step.
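How the two losses combine in a semi-supervised training step can be sketched with plain numbers standing in for PyTorch loss tensors (a hypothetical helper for illustration, not part of the Trainer API):

```python
from typing import Optional


def training_step_loss(sup_loss: float, ss_loss: Optional[float] = None) -> float:
    # Sketch of the rule described above: the ss_criterion loss on unlabeled
    # batches is simply added to the supervised criterion loss. Without an
    # unlabeled_dataset/ss_criterion, training is purely supervised.
    # (Plain floats stand in for loss tensors here.)
    if ss_loss is None:
        return sup_loss
    return sup_loss + ss_loss
```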
save_root (str) – Root directory where training-related files are stored. Files are always written to the subdirectory save_root/exp_name.
exp_name (Optional[str]) – Name of the training experiment. Determines the subdirectory to which files are written and should uniquely identify one training experiment. If exp_name is not set, it is auto-generated from the model name and a time stamp.
example_input (Optional[Tensor]) – An example input tensor that can be fed to the model. This is used for JIT tracing during model serialization.
save_jit (Optional[str]) – Chooses if/how a JIT version (.pts file) of the model should always be saved in addition to regular model snapshots. Choices:
- None (default): Disable saving JIT models.
- 'script' (recommended if possible): The model is compiled directly with torch.jit.script() and saved as a .pts file.
- 'trace': The model is JIT-traced with example_input and saved as a .pts file.
batch_size (int) – Desired batch size of training samples.
preview_batch (Optional[Tensor]) – Set a fixed input batch for preview predictions. If it is None (default), preview batch functionality will be disabled.
preview_interval (int) – Determines how often to perform preview inference. Preview inference is performed every preview_interval epochs during training. Regardless of this value, preview predictions will also be performed once after epoch 1. (To disable preview predictions altogether, just set preview_batch = None.)
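The interaction between preview_interval and the always-performed epoch-1 preview can be sketched like this (a hypothetical helper illustrating the stated rule, not Trainer internals):

```python
def should_run_preview(epoch: int, preview_interval: int) -> bool:
    """Decide whether preview inference runs after a given epoch, per the rule
    described above: always after epoch 1, then every `preview_interval`
    epochs. (Illustrative sketch only.)"""
    return epoch == 1 or epoch % preview_interval == 0
```

With the default preview_interval=5, previews would run after epochs 1, 5, 10, 15 and so on.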
inference_kwargs (Optional[Dict[str, Any]]) – Additional options that are supplied to the elektronn3.inference.Predictor instance that is used for periodic preview inference on the preview_batch.
extra_save_steps (Sequence[int]) – Permanent model snapshots are saved at the training steps specified here. E.g. with extra_save_steps = (0, 30, 3000), a snapshot is made at step 0 (before training begins), step 30 and step 3000.
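The snapshot rule above amounts to a simple membership check per training step (hypothetical helper for illustration, not part of the Trainer API):

```python
from typing import Sequence


def is_extra_save_step(step: int, extra_save_steps: Sequence[int] = (0, 30, 3000)) -> bool:
    # A permanent snapshot is made whenever the current training step is
    # listed in extra_save_steps; step 0 means "before training begins".
    # (Illustrative sketch only; the default tuple here is the example above.)
    return step in extra_save_steps
```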
num_workers (int) – Number of background processes that are used to produce training samples without blocking the main training loop. See torch.utils.data.DataLoader. For normal training, you can mostly set num_workers=1. Only use more workers if you notice a data loader bottleneck. Set num_workers=0 if you want to debug the dataset implementation, to avoid multiprocessing-specific issues.
schedulers (Optional[Dict[Any, Any]]) – Dictionary of schedulers for training hyperparameters, e.g. learning rate schedulers that can be found in torch.optim.lr_scheduler.
overlay_alpha (float) – Alpha (transparency) value for alpha-blending of overlay image plots.
enable_videos (bool) – Enables video visualizations for 3D image data in tensorboard. Requires the moviepy package. Warning: Videos are stored as GIFs and can get very large, so only use this if you log rarely or have a lot of storage capacity.
enable_tensorboard (bool) – If True, tensorboard logging/plotting is enabled during training.
tensorboard_root_path (Optional[str]) – Path to the root directory under which tensorboard log directories are created. Log (“event”) files are written to a subdirectory that has the same name as the training experiment. If tensorboard_root_path is not set, tensorboard logs are written to save_path (next to model checkpoints, plots etc.).
ignore_errors (bool) – If True, the training process tries to ignore all errors and continue with the next batch if it encounters an error on the current batch. It’s not recommended to use this. It’s only helpful for certain debugging scenarios.
ipython_shell (bool) – If True, keyboard interrupts (Ctrl-C) won’t exit the process but only pause training and enter an IPython shell. Additionally, errors during training (except C-level segfaults etc.) won’t crash the whole training process, but drop to an IPython shell so errors can be inspected with access to the current training state.
out_channels (Optional[int]) – Optionally specifies the total number of different target classes for classification tasks. If this is not set manually, the Trainer checks if the train_dataset provides this value. If available, self.out_channels is set to self.train_dataset.out_channels. Otherwise, it is set to None. The out_channels attribute is used for plotting purposes and is not strictly required for training.
sample_plotting_handler (Optional[Callable]) – Function that receives training and validation samples and is responsible for visualizing them by e.g. plotting them to tensorboard and/or writing them to files. It is called once after each training epoch and once after each validation run. If None, a tensorboard-based default handler is used that works for most classification scenarios and for 3-channel regression.
preview_plotting_handler (Optional[Callable]) – Function that is responsible for producing previews and visualizing/plotting/logging them. It is called once every preview_interval epochs. If None, a tensorboard-based default handler is used that works for most classification scenarios.
mixed_precision (bool) – If True, enable Automated Mixed Precision training powered by https://github.com/NVIDIA/apex to reduce memory usage and (if a GPU with Tensor Cores is used) make training much faster. This is currently experimental and might cause instabilities.
Train the network for max_steps steps. After each training epoch, validation performance is measured and visualizations are computed and logged to tensorboard.
- Return type: None
Find currently living tensors that are allocated on cuda device memory. This can be used for debugging memory leaks: if the total memory usage reported by findcudatensors() grows unexpectedly between GPU computations, you can look at the returned tensors list to find out what tensors are currently allocated, for example with print([x.shape for x in findcudatensors()[1]]).
Returns a tuple of
- total memory usage of found tensors in MiB
- a list of all of those tensors, ordered by size.
- Return type: Tuple[int, List[Tensor]]
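The underlying technique (scanning all objects tracked by the garbage collector and filtering by a predicate) can be sketched in a framework-free way. This is an illustrative generalization, not the actual implementation; findcudatensors would apply the same idea with a predicate like lambda o: torch.is_tensor(o) and o.is_cuda and tensor byte sizes:

```python
import gc
from typing import Callable, List, Tuple


def _safe(pred: Callable[[object], bool], obj: object) -> bool:
    # Some gc-tracked objects misbehave on inspection; treat those as
    # non-matching instead of crashing the scan.
    try:
        return bool(pred(obj))
    except Exception:
        return False


def find_live_objects(predicate: Callable[[object], bool],
                      size_of: Callable[[object], int]) -> Tuple[int, List[object]]:
    """Scan all gc-tracked objects and return (total size, matching objects
    ordered by size, largest first). Hypothetical sketch of the technique."""
    found = [o for o in gc.get_objects() if _safe(predicate, o)]
    found.sort(key=size_of, reverse=True)
    return sum(size_of(o) for o in found), found
```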