elektronn3.training.trainer module

class elektronn3.training.trainer.Backup(script_path, save_path)[source]

Bases: object

Backup class for archiving the training script, the src folder and environment info. Should be used for any future archiving needs.

  • script_path – The path to the training script, e.g. train_unet_neurodata.py

  • save_path – The path where the information is archived.


Archives the source folder, the training script and environment info.

The training script is saved with the prefix “0-” to distinguish it from regular scripts. Environment information equivalent to the output of python -m torch.utils.collect_env is saved in a file named “env_info.txt”.
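
As a rough illustration of the archiving behaviour described above, the following sketch copies a script with the “0-” prefix and writes an env_info.txt file. It is not the Backup implementation: the function name archive_script is hypothetical, the source folder is not archived, and basic interpreter and platform details stand in for the output of python -m torch.utils.collect_env.

```python
import platform
import shutil
import sys
from pathlib import Path


def archive_script(script_path, save_path):
    """Hypothetical helper: copy a training script and record environment info."""
    script = Path(script_path)
    target_dir = Path(save_path)
    target_dir.mkdir(parents=True, exist_ok=True)

    # The "0-" prefix marks the archived copy of the training script.
    archived = target_dir / f"0-{script.name}"
    shutil.copy(script, archived)

    # Stand-in for `python -m torch.utils.collect_env` (assumption: only
    # interpreter and platform details are recorded in this sketch).
    env_file = target_dir / "env_info.txt"
    env_file.write_text(
        f"Python {sys.version}\nPlatform {platform.platform()}\n"
    )
    return archived, env_file
```

Calling archive_script('train_unet_neurodata.py', 'some/save/dir') would place 0-train_unet_neurodata.py and env_info.txt in that directory.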

exception elektronn3.training.trainer.NaNException[source]

Bases: RuntimeError

Raised when a NaN value is detected.
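
The sketch below shows one way such an exception could be raised from a training loop. Where and how the Trainer actually checks for NaN values is not stated here, so the check_loss helper and the locally defined NaNException class are illustrative assumptions only.

```python
import math


class NaNException(RuntimeError):
    """Local stand-in mirroring the documented base class (RuntimeError)."""


def check_loss(loss_value: float, step: int) -> float:
    """Hypothetical guard: raise if the loss of the given step is NaN."""
    if math.isnan(loss_value):
        raise NaNException(f"NaN loss detected at step {step}")
    return loss_value
```

Because the exception derives from RuntimeError, an existing except RuntimeError handler would also catch it.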

class elektronn3.training.trainer.Trainer(model, criterion, optimizer, device, save_root, train_dataset, valid_dataset=None, unlabeled_dataset=None, valid_metrics=None, ss_criterion=None, preview_batch=None, knossos_preview_config=None, preview_interval=5, inference_kwargs=None, hparams=None, extra_save_steps=(), exp_name=None, example_input=None, enable_save_trace=False, save_jit=None, batch_size=1, num_workers=0, schedulers=None, overlay_alpha=0.4, enable_videos=False, enable_tensorboard=True, tensorboard_root_path=None, ignore_errors=False, ipython_shell=False, out_channels=None, sample_plotting_handler=None, preview_plotting_handler=None, mixed_precision=False, tqdm_kwargs=None)[source]

Bases: object

General training loop abstraction for supervised training.

  • model (Module) – PyTorch model (nn.Module) that shall be trained. Please make sure that the output shape of the model matches the shape of targets that are delivered by the train_dataset.

  • criterion (Module) – PyTorch loss that shall be used as the optimization criterion.

  • optimizer (Optimizer) – PyTorch optimizer that shall be used to update model weights according to the criterion in each iteration.

  • device (device) – The device on which the network shall be trained.

  • train_dataset (Dataset) – PyTorch dataset (data.Dataset) which produces training samples when iterated over. elektronn3.data.cnndata.PatchCreator is currently recommended for constructing datasets.

  • valid_dataset (Optional[Dataset]) – PyTorch dataset (data.Dataset) which produces validation samples when iterated over. Its length (len(valid_dataset)) determines how many samples are used for one validation metric calculation.

  • unlabeled_dataset (Optional[Dataset]) – Unlabeled dataset (only inputs) for semi-supervised training. If this is supplied, ss_criterion needs to be set to the loss that should be computed on unlabeled inputs.

  • valid_metrics (Optional[Dict]) – Validation metrics to be calculated on validation data after each training epoch. All metrics are logged to tensorboard.

  • ss_criterion (Optional[Module]) – Loss criterion for the self-supervised part of semi-supervised training. The ss_criterion loss is computed on batches from the unlabeled_dataset and added to the supervised loss in each training step.

  • save_root (str) – Root directory where training-related files are stored. Files are always written to the subdirectory save_root/exp_name/.

  • exp_name (Optional[str]) – Name of the training experiment. Determines the subdirectory to which files are written and should uniquely identify one training experiment. If exp_name is not set, it is auto-generated from the model name and a time stamp in the format '%y-%m-%d_%H-%M-%S'.

  • example_input (Optional[Tensor]) – An example input tensor that can be fed to the model. This is used for JIT tracing during model serialization.

  • save_jit (Optional[str]) –

    Chooses if/how a JIT version (.pts file) of the model should always be saved in addition to regular model snapshots. Choices:

    • None (default): Disable saving JIT models.

    • 'script' (recommended if possible): The model is compiled directly with torch.jit.script() and saved as a .pts file

    • 'trace': The model is JIT-traced with example_input and saved as a .pts file

  • batch_size (int) – Desired batch size of training samples.

  • preview_batch (Optional[Tensor]) – Set a fixed input batch for preview predictions. If it is None (default), preview batch functionality will be disabled. As a more powerful alternative for KNOSSOS datasets, consider using the knossos_preview_config option instead.

  • knossos_preview_config (Optional[Dict[str, str]]) –

    Configures preview batch creation and preview inferences based on a KNOSSOS dataset region. Here is an example of what it should look like:

    >>> knossos_preview_config = {
    ...     'dataset': 'path/to/knossos/dataset.conf',
    ...     'offset': [0, 0, 0],  # Offset (min) coordinates
    ...     'size': [256, 256, 64],  # Size (shape) of the region
    ...     'mag': 1,  # source mag
    ...     'target_mags': [1, 2, 3],  # List of target mags to which the inference results should be written
    ...     'remap_ids': None  # Config for class ID remapping (optional). See transforms.RemapTargetIDs
    ... }

    Periodic preview inference results are written to .k.zip annotation files that can be loaded with KNOSSOS and overlaid on the original data. .k.zip files are saved in the training directory, with file names reflecting their training step.

  • preview_interval (int) – Determines how often to perform preview inference. Preview inference is performed every preview_interval epochs during training. Regardless of this value, preview predictions will also be performed once after epoch 1. (To disable preview predictions altogether, just set preview_batch = None).

  • inference_kwargs (Optional[Dict[str, Any]]) – Additional options that are supplied to the elektronn3.inference.Predictor instance that is used for periodic preview inference on the preview_batch.

  • extra_save_steps (Sequence[int]) – Permanent model snapshots are saved at the training steps specified here. E.g. with extra_save_steps = (0, 30, 3000), snapshots are made at step 0 (before training begins), step 30 and step 3000.

  • num_workers (int) – Number of background processes that are used to produce training samples without blocking the main training loop. See torch.utils.data.DataLoader. For normal training, num_workers=1 is usually sufficient. Only use more workers if you notice a data loader bottleneck. Set num_workers=0 if you want to debug the dataset implementation, to avoid multiprocessing-specific issues.

  • schedulers (Optional[Dict[Any, Any]]) – Dictionary of schedulers for training hyperparameters, e.g. learning rate schedulers that can be found in torch.optim.lr_scheduler.

  • overlay_alpha (float) – Alpha (transparency) value for alpha-blending of overlay image plots.

  • enable_videos (bool) – Enables video visualizations for 3D image data in tensorboard. Requires the moviepy package. Warning: Videos are stored as GIFs and can get very large, so only use this if you log rarely or have a lot of storage capacity.

  • enable_tensorboard (bool) – If True, tensorboard logging/plotting is enabled during training.

  • tensorboard_root_path (Optional[str]) – Path to the root directory under which tensorboard log directories are created. Log (“event”) files are written to a subdirectory that has the same name as the exp_name. If tensorboard_root_path is not set, tensorboard logs are written to save_path (next to model checkpoints, plots etc.).

  • ignore_errors (bool) – If True, the training process tries to ignore all errors and continue with the next batch if it encounters an error on the current batch. It’s not recommended to use this. It’s only helpful for certain debugging scenarios.

  • ipython_shell (bool) – If True, keyboard interrupts (Ctrl-C) won’t exit the process but only pause training and enter an IPython shell. Additionally, errors during training (except C-level segfaults etc.) won’t crash the whole training process, but drop to an IPython shell so errors can be inspected with access to the current training state.

  • out_channels (Optional[int]) – Optionally specifies the total number of different target classes for classification tasks. If this is not set manually, the Trainer checks if the train_dataset provides this value. If available, self.out_channels is set to self.train_dataset.out_channels. Otherwise, it is set to None. The out_channels attribute is used for plotting purposes and is not strictly required for training.

  • sample_plotting_handler (Optional[Callable]) – Function that receives training and validation samples and is responsible for visualizing them by e.g. plotting them to tensorboard and/or writing them to files. It is called once after each training epoch and once after each validation run. If None, a tensorboard-based default handler is used that works for most classification scenarios and for 3-channel regression.

  • preview_plotting_handler (Optional[Callable]) – Function that is responsible for producing previews and visualizing/plotting/logging them. It is called once every preview_interval epochs. If None, a tensorboard-based default handler is used that works for most classification scenarios.

  • mixed_precision (bool) – If True, enable Automatic Mixed Precision training powered by https://github.com/NVIDIA/apex to reduce memory usage and (if a GPU with Tensor Cores is used) make training much faster. This is currently experimental and might cause instabilities.

  • tqdm_kwargs (Optional[Dict]) – Extra arguments to be passed to tqdm progress bars. For example, to disable tqdm outputs completely, pass tqdm_kwargs={'disable': True}.
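
To make the exp_name and save_root descriptions above more concrete, here is a minimal sketch of how an auto-generated experiment name and the resulting output directory could be derived. The helper names make_exp_name and make_save_path are hypothetical, and the double-underscore separator between model name and time stamp is an assumption; only the time stamp format '%y-%m-%d_%H-%M-%S' and the save_root/exp_name/ layout come from the parameter descriptions.

```python
import os
from datetime import datetime
from typing import Optional


def make_exp_name(model_name: str, now: Optional[datetime] = None) -> str:
    """Hypothetical helper: combine a model name with a time stamp."""
    now = now or datetime.now()
    timestamp = now.strftime('%y-%m-%d_%H-%M-%S')
    # Assumption: the separator between name and time stamp is "__".
    return f"{model_name}__{timestamp}"


def make_save_path(save_root: str, exp_name: str) -> str:
    """Files are written to the subdirectory save_root/exp_name/."""
    return os.path.join(save_root, exp_name)
```

Passing an explicit exp_name to the Trainer avoids relying on the auto-generated name, which is useful when a run should be easy to find again later.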

epoch: int
exp_name: str
out_channels: Optional[int]
run(max_steps=1, max_runtime=604800)[source]

Train the network for max_steps steps. After each training epoch, validation performance is measured and visualizations are computed and logged to tensorboard.
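
The interplay of the two arguments can be pictured with the simplified loop below. This is only a sketch and not the Trainer's code: the function run_sketch and its train_step and clock parameters are hypothetical, epochs and validation are left out, and max_runtime is assumed to be given in seconds (604800 seconds would correspond to one week).

```python
import time
from typing import Callable


def run_sketch(
    train_step: Callable[[int], None],
    max_steps: int = 1,
    max_runtime: float = 604800,
    clock: Callable[[], float] = time.monotonic,
) -> int:
    """Hypothetical loop: stop at max_steps or once max_runtime has elapsed."""
    start = clock()
    step = 0
    # Assumption: max_runtime is measured in seconds since the loop started.
    while step < max_steps and clock() - start < max_runtime:
        train_step(step)
        step += 1
    return step  # number of steps that were actually executed
```

Whichever limit is reached first ends the loop, so a generous max_steps combined with a tight max_runtime behaves like a time budget.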

Return type


save_path: str
step: int
tb: tensorboardX.SummaryWriter
terminate: bool
train_loader: torch.utils.data.DataLoader
valid_loader: torch.utils.data.DataLoader

Find currently living tensors that are allocated on cuda device memory. This can be used for debugging memory leaks: If findcudatensors()[0] grows unexpectedly between GPU computations, you can look at the returned tensors list to find out what tensors are currently allocated, for example print([x.shape for x in findcudatensors()[1]]).

Returns a tuple of

  • total memory usage of found tensors in MiB

  • a list of all of those tensors, ordered by size.

Return type

Tuple[int, List[Tensor]]
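
A function with this contract could be approximated as sketched below. This is not the library's implementation: the names live_cuda_tensors and summarize are hypothetical, scanning the garbage collector's tracked objects is an assumed strategy, and ascending order by size is a guess, since the description only says "ordered by size".

```python
import gc
from typing import Any, List, Tuple


def live_cuda_tensors() -> List[Any]:
    """Collect tensors on CUDA devices that the garbage collector tracks."""
    import torch  # imported lazily so summarize() works without PyTorch

    found = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                found.append(obj)
        except Exception:
            # Some tracked objects raise on inspection; skip them.
            continue
    return found


def summarize(tensors: List[Any]) -> Tuple[int, List[Any]]:
    """Return (total size in MiB, tensors ordered by size)."""
    def nbytes(t: Any) -> int:
        return t.numel() * t.element_size()

    ordered = sorted(tensors, key=nbytes)  # ascending order is an assumption
    total_mib = sum(nbytes(t) for t in ordered) // 1024 ** 2
    return total_mib, ordered
```

With PyTorch installed, summarize(live_cuda_tensors()) yields a tuple of the same shape as described above, which can be compared before and after a suspicious computation.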