elektronn3.training.trainer module
- class elektronn3.training.trainer.Backup(script_path, save_path, extra_content=None)
Bases: object
Backup class for archiving the training script, the src folder and environment info. It should be used for any future archiving needs.
- Parameters:
script_path – The path to the training script, e.g. train_unet_neurodata.py.
save_path – The path where the information is archived.
extra_content – Dictionary of {filename: content} entries, where content is a string that should be written to a file with the specified name.
- archive_backup()
Archives the source folder, the training script and environment info.
The training script is saved with the prefix “0-” to distinguish it from regular scripts. Environment information equivalent to the output of python -m torch.utils.collect_env is saved in a file named “env_info.txt”.
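Here is a minimal, standard-library-only sketch of two of the archiving steps described for archive_backup(). It copies the training script under a “0-” prefix and writes every {filename: content} entry of extra_content to its own file. The function name archive_backup_sketch is hypothetical. The sketch deliberately leaves out copying the src folder and writing “env_info.txt”, so it is not the library's implementation.

```python
import os
import shutil
from typing import Dict, Optional


def archive_backup_sketch(script_path: str, save_path: str,
                          extra_content: Optional[Dict[str, str]] = None) -> None:
    """Hypothetical sketch: archive a training script and extra text files."""
    os.makedirs(save_path, exist_ok=True)
    # Copy the training script with the "0-" prefix to distinguish it from regular scripts
    script_name = os.path.basename(script_path)
    shutil.copy(script_path, os.path.join(save_path, '0-' + script_name))
    # Write each {filename: content} entry of extra_content to a file of that name
    for filename, content in (extra_content or {}).items():
        with open(os.path.join(save_path, filename), 'w') as f:
            f.write(content)
```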
- exception elektronn3.training.trainer.NaNException
Bases: RuntimeError
Raised when a NaN value is detected.
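The sketch below shows how an exception like this might be used to abort training once the loss becomes NaN. It assumes the loss is available as a plain Python float; with a PyTorch scalar tensor you could test loss.item() or use torch.isnan(loss). The class here only mirrors the documented base class, and check_loss is a hypothetical helper rather than part of elektronn3.

```python
import math


class NaNException(RuntimeError):
    """Raised when a NaN value is detected (mirrors the documented base class)."""


def check_loss(loss: float, step: int) -> float:
    # Hypothetical helper: stop training as soon as the loss turns into NaN
    if math.isnan(loss):
        raise NaNException(f'NaN loss detected at step {step}')
    return loss
```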
- class elektronn3.training.trainer.Trainer(model, criterion, optimizer, device, save_root, train_dataset, valid_dataset=None, unlabeled_dataset=None, valid_metrics=None, ss_criterion=None, preview_batch=None, knossos_preview_config=None, preview_interval=5, inference_kwargs=None, hparams=None, extra_save_steps=(), exp_name=None, example_input=None, enable_save_trace=False, save_jit=None, batch_size=1, num_workers=0, schedulers=None, overlay_alpha=0.4, enable_videos=False, enable_tensorboard=True, tensorboard_root_path=None, ignore_errors=False, ipython_shell=False, out_channels=None, sample_plotting_handler=None, preview_plotting_handler=None, mixed_precision=False, tqdm_kwargs=None)
Bases: object
General training loop abstraction for supervised training.
- Parameters:
model (Module) – PyTorch model (nn.Module) that shall be trained. Please make sure that the output shape of the model matches the shape of targets that are delivered by the train_dataset.
criterion (Module) – PyTorch loss that shall be used as the optimization criterion.
optimizer (Optimizer) – PyTorch optimizer that shall be used to update model weights according to the criterion in each iteration.
device (device) – The device on which the network shall be trained.
train_dataset (Dataset) – PyTorch dataset (data.Dataset) which produces training samples when iterated over. elektronn3.data.cnndata.PatchCreator is currently recommended for constructing datasets.
valid_dataset (Optional[Dataset]) – PyTorch dataset (data.Dataset) which produces validation samples when iterated over. Its length (len(valid_dataset)) determines how many samples are used for one validation metric calculation.
unlabeled_dataset (Optional[Dataset]) – Unlabeled dataset (only inputs) for semi-supervised training. If this is supplied, ss_criterion needs to be set to the loss that should be computed on unlabeled inputs.
valid_metrics (Optional[Dict]) – Validation metrics to be calculated on validation data after each training epoch. All metrics are logged to tensorboard.
ss_criterion (Optional[Module]) – Loss criterion for the self-supervised part of semi-supervised training. The ss_criterion loss is computed on batches from the unlabeled_dataset and added to the supervised loss in each training step.
save_root (str) – Root directory where training-related files are stored. Files are always written to the subdirectory save_root/exp_name/.
exp_name (Optional[str]) – Name of the training experiment. Determines the subdirectory to which files are written and should uniquely identify one training experiment. If exp_name is not set, it is auto-generated from the model name and a time stamp in the format '%y-%m-%d_%H-%M-%S'.
example_input (Optional[Tensor]) – An example input tensor that can be fed to the model. This is used for JIT tracing during model serialization.
save_jit (Optional[str]) – Chooses if/how a JIT version (.pts file) of the model should always be saved in addition to regular model snapshots. Choices:
  - None (default): Disable saving JIT models.
  - 'script' (recommended if possible): The model is compiled directly with torch.jit.script() and saved as a .pts file.
  - 'trace': The model is JIT-traced with example_input and saved as a .pts file.
batch_size (int) – Desired batch size of training samples.
preview_batch (Optional[Tensor]) – Set a fixed input batch for preview predictions. If it is None (default), preview batch functionality is disabled. As a more powerful alternative for KNOSSOS datasets, consider using the knossos_preview_config option instead.
knossos_preview_config (Optional[Dict[str, str]]) – Configures preview batch creation and preview inferences based on a KNOSSOS dataset region. Here is an example of what it should look like:
>>> knossos_preview_config = {
...     'dataset': 'path/to/knossos/dataset.conf',
...     'offset': [0, 0, 0],  # Offset (min) coordinates
...     'size': [256, 256, 64],  # Size (shape) of the region
...     'mag': 1,  # source mag
...     'target_mags': [1, 2, 3],  # List of target mags to which the inference results should be written
...     'remap_ids': None  # Config for class ID remapping (optional). See transforms.RemapTargetIDs
... }
Periodic preview inference results are written to .k.zip annotation files that can be loaded with KNOSSOS and overlaid on the original data. .k.zip files are saved in the training directory, with file names reflecting their training step.
preview_interval (int) – Determines how often to perform preview inference. Preview inference is performed every preview_interval epochs during training. Regardless of this value, preview predictions are also performed once after epoch 1. (To disable preview predictions altogether, just set preview_batch = None.)
inference_kwargs (Optional[Dict[str, Any]]) – Additional options that are supplied to the elektronn3.inference.Predictor instance that is used for periodic preview inference on the preview_batch.
extra_save_steps (Sequence[int]) – Permanent model snapshots are saved at the training steps specified here. E.g. with extra_save_steps = (0, 30, 3000), a snapshot is made at step 0 (before training begins), step 30 and step 3000.
num_workers (int) – Number of background processes that are used to produce training samples without blocking the main training loop. See torch.utils.data.DataLoader. For normal training, you can mostly set num_workers=1. Only use more workers if you notice a data loader bottleneck. Set num_workers=0 if you want to debug the dataset implementation, to avoid multiprocessing-specific issues.
schedulers (Optional[Dict[Any, Any]]) – Dictionary of schedulers for training hyperparameters, e.g. the learning rate schedulers in torch.optim.lr_scheduler.
overlay_alpha (float) – Alpha (transparency) value for alpha-blending of overlay image plots.
enable_videos (bool) – Enables video visualizations for 3D image data in tensorboard. Requires the moviepy package. Warning: Videos are stored as GIFs and can get very large, so only use this if you log rarely or have a lot of storage capacity.
enable_tensorboard (bool) – If True, tensorboard logging/plotting is enabled during training.
tensorboard_root_path (Optional[str]) – Path to the root directory under which tensorboard log directories are created. Log (“event”) files are written to a subdirectory that has the same name as the exp_name. If tensorboard_root_path is not set, tensorboard logs are written to save_path (next to model checkpoints, plots etc.).
ignore_errors (bool) – If True, the training process tries to ignore all errors and continue with the next batch if it encounters an error on the current batch. This is not recommended and is only helpful in certain debugging scenarios.
ipython_shell (bool) – If True, keyboard interrupts (Ctrl-C) won’t exit the process but only pause training and enter an IPython shell. Additionally, errors during training (except C-level segfaults etc.) won’t crash the whole training process, but drop to an IPython shell so errors can be inspected with access to the current training state.
out_channels (Optional[int]) – Optionally specifies the total number of different target classes for classification tasks. If this is not set manually, the Trainer checks if the train_dataset provides this value. If available, self.out_channels is set to self.train_dataset.out_channels. Otherwise, it is set to None. The out_channels attribute is used for plotting purposes and is not strictly required for training.
sample_plotting_handler (Optional[Callable]) – Function that receives training and validation samples and is responsible for visualizing them, e.g. by plotting them to tensorboard and/or writing them to files. It is called once after each training epoch and once after each validation run. If None, a tensorboard-based default handler is used that works for most classification scenarios and for 3-channel regression.
preview_plotting_handler (Optional[Callable]) – Function that is responsible for producing previews and visualizing/plotting/logging them. It is called once every preview_interval epochs. If None, a tensorboard-based default handler is used that works for most classification scenarios.
mixed_precision (bool) – If True, enable Automatic Mixed Precision training powered by https://github.com/NVIDIA/apex to reduce memory usage and (if a GPU with Tensor Cores is used) make training much faster. This is currently experimental and might cause instabilities.
tqdm_kwargs (Optional[Dict]) – Extra arguments to be passed to tqdm progress bars. For example, to disable tqdm outputs completely, pass tqdm_kwargs={'disable': True}.
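The exp_name description says that an unset name is auto-generated from the model name and a time stamp in the format '%y-%m-%d_%H-%M-%S', and that files go to save_root/exp_name/. The sketch below illustrates that naming scheme. It rests on two assumptions: the "model name" is taken to be the model's class name, and the "__" separator between the parts is an illustrative choice. The helpers make_exp_name and experiment_dir are hypothetical.

```python
import os
from datetime import datetime
from typing import Any, Optional

TIMESTAMP_FORMAT = '%y-%m-%d_%H-%M-%S'  # format stated in the exp_name description


def make_exp_name(model: Any, now: Optional[datetime] = None) -> str:
    # Assumption: use the model's class name; "__" as separator is illustrative
    now = datetime.now() if now is None else now
    return f'{type(model).__name__}__{now.strftime(TIMESTAMP_FORMAT)}'


def experiment_dir(save_root: str, exp_name: str) -> str:
    # Files are written to the subdirectory save_root/exp_name/
    return os.path.join(save_root, exp_name)
```

Because the time stamp only has second resolution, two experiments started within the same second would collide. Passing an explicit exp_name avoids that.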
- epoch: int
- exp_name: str
- out_channels: Optional[int]
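The out_channels parameter description lays out a fallback order: an explicitly passed value, then train_dataset.out_channels if the dataset provides it, then None. The hypothetical helper below expresses that order with getattr. The real Trainer may check for the attribute differently.

```python
from typing import Any, Optional


def resolve_out_channels(out_channels: Optional[int], train_dataset: Any) -> Optional[int]:
    # An explicitly passed value takes precedence
    if out_channels is not None:
        return out_channels
    # Otherwise use the dataset's out_channels attribute if present, else None
    return getattr(train_dataset, 'out_channels', None)
```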
- run(max_steps=1, max_runtime=604800)
Train the network for max_steps steps. After each training epoch, validation performance is measured and visualizations are computed and logged to tensorboard.
- Return type: None
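The following sketch ties run() to the extra_save_steps and preview_interval parameters. It models only the order in which events could happen, not actual training. The default max_runtime=604800 equals the number of seconds in seven days, so the sketch assumes max_runtime is a wall-clock limit in seconds. Several other details are assumptions too: the names run_schedule and should_preview, the fixed steps_per_epoch argument, checking the runtime only between epochs, and counting epochs from 1.

```python
import time
from typing import Callable, List, Sequence, Tuple


def should_preview(epoch: int, preview_interval: int) -> bool:
    # Preview every `preview_interval` epochs, and once after epoch 1
    return epoch == 1 or epoch % preview_interval == 0


def run_schedule(
    max_steps: int,
    max_runtime: float,
    steps_per_epoch: int,
    extra_save_steps: Sequence[int] = (),
    preview_interval: int = 5,
    clock: Callable[[], float] = time.monotonic,
) -> List[Tuple[str, int]]:
    """Hypothetical sketch: list the events a training loop would trigger."""
    events: List[Tuple[str, int]] = []
    start = clock()
    step = 0
    epoch = 0
    # Stop when the step budget or the runtime budget (seconds) is used up;
    # in this sketch the runtime is only checked between epochs.
    while step < max_steps and clock() - start < max_runtime:
        for _ in range(steps_per_epoch):
            if step in extra_save_steps:
                events.append(('snapshot', step))  # taken before the update at `step`
            # ... forward pass, loss, backward pass and optimizer step go here ...
            step += 1
            if step >= max_steps:
                break
        epoch += 1
        events.append(('validate', epoch))
        if should_preview(epoch, preview_interval):
            events.append(('preview', epoch))
    return events
```

With extra_save_steps=(0, ...), the first snapshot is recorded before any update has happened, which matches the "step 0 (before training begins)" wording in the parameter description.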
- save_path: str
- step: int
- tb: SummaryWriter
- terminate: bool
- train_loader: DataLoader
- valid_loader: DataLoader
- elektronn3.training.trainer.findcudatensors()
Find currently living tensors that are allocated on CUDA device memory. This can be used for debugging memory leaks: if findcudatensors()[0] grows unexpectedly between GPU computations, you can look at the returned tensors list to find out what tensors are currently allocated, for example print([x.shape for x in findcudatensors()[1]]).
Returns a tuple of:
  - total memory usage of found tensors in MiB
  - a list of all of those tensors, ordered by size.
- Return type: Tuple[int, List[Tensor]]
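One common way to find living objects of a given kind in Python is to scan gc.get_objects(). The generic sketch below assumes that approach, but we do not claim that findcudatensors is implemented this way. The function find_live_objects is hypothetical. It takes a predicate and a byte-size function so that it can run without a GPU. It returns the total as a float in MiB, whereas findcudatensors documents an int, and it sorts largest first, which is an illustrative choice. With PyTorch, you might pass lambda o: torch.is_tensor(o) and o.is_cuda as the predicate and lambda t: t.element_size() * t.nelement() as the size function.

```python
import gc
from typing import Any, Callable, List, Tuple


def find_live_objects(
    predicate: Callable[[Any], bool],
    nbytes: Callable[[Any], int],
) -> Tuple[float, List[Any]]:
    """Hypothetical sketch: (total MiB, matching objects sorted largest first)."""
    # Only objects tracked by the garbage collector are visible here
    found = [obj for obj in gc.get_objects() if predicate(obj)]
    found.sort(key=nbytes, reverse=True)
    total_mib = sum(nbytes(obj) for obj in found) / 2**20
    return total_mib, found
```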