Source code for elektronn3.training.trainer

# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Martin Drawitsch, Philipp Schubert
import datetime
import pprint
from collections import deque

import gc
import logging
import os
import shutil
import warnings
import zipfile

from itertools import islice
from math import nan
from pickle import PickleError
from textwrap import dedent
from typing import Tuple, Dict, Optional, Callable, Any, Sequence, List, Union

import inspect
import IPython
import numpy as np
import tensorboardX
import torch
import torch.utils.data
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.cuda import amp
from tqdm import tqdm


import elektronn3
from elektronn3.training import handlers
from elektronn3.training.swa import SWA
from elektronn3.training.train_utils import pretty_string_time, create_preview_batch_from_knossos
from elektronn3.training.train_utils import Timer, HistoryTracker

from torch.utils import collect_env
from elektronn3.inference import Predictor
from elektronn3 import __file__ as arch_src

logger = logging.getLogger('elektronn3log')


class NaNException(RuntimeError):
    """Signals that a NaN value was detected."""
def _worker_init_fn(worker_id: int) -> None:
    """Sets a unique but deterministic random seed for background workers.

    Only sets the seed for NumPy because PyTorch and Python's own RNGs
    take care of reseeding on their own.
    See https://github.com/numpy/numpy/issues/9650.

    Args:
        worker_id: Integer ID of the background worker. It is added to
            ``torch.initial_seed()`` to derive the NumPy seed.
    """
    # np.random.seed() only accepts values up to 2**32 - 1, so the modulo has
    # to be applied *after* adding the worker ID. Otherwise the sum can
    # exceed the valid range and np.random.seed() raises a ValueError.
    worker_seed = (torch.initial_seed() + worker_id) % 2**32
    np.random.seed(worker_seed)


# Be careful from where you call this! Not sure if this is concurrency-safe.
def _change_log_file_to(
        new_path: str,
        transfer_old_logs: bool = True,
        delete_old_file: bool = True
) -> None:
    """Transfer the current log file to a new location and redirect logs.

    Args:
        new_path: Path of the file that the first file handler of the
            module-level ``logger`` should write to from now on.
        transfer_old_logs: If ``True``, the content of the current log file
            is written to ``new_path`` (overwriting it) before redirecting.
        delete_old_file: If ``True``, the previous log file is removed.

    Raises:
        RuntimeError: If ``logger`` has no ``logging.FileHandler``.
    """

    def _get_first_file_handler() -> logging.FileHandler:
        for handler in logger.handlers:
            if isinstance(handler, logging.FileHandler):
                return handler
        # This has to be raised (not returned): a returned exception object
        # would only fail later with an unrelated AttributeError.
        raise RuntimeError('logger has no FileHandler.')

    # Getting the first (and presumably only) file handler
    file_handler = _get_first_file_handler()
    if transfer_old_logs:
        with open(file_handler.baseFilename) as f:
            old_logs = f.read()
        with open(new_path, 'w') as f:
            f.write(old_logs)
    # Close the stream before touching the old file and switching the path.
    file_handler.close()
    if delete_old_file:
        os.remove(file_handler.baseFilename)
    file_handler.baseFilename = new_path
class Trainer:
    """ General training loop abstraction for supervised training.

    Args:
        model: PyTorch model (``nn.Module``) that shall be trained.
            Please make sure that the output shape of the ``model``
            matches the shape of targets that are delivered by the
            ``train_dataset``.
        criterion: PyTorch loss that shall be used as the optimization
            criterion.
        optimizer: PyTorch optimizer that shall be used to update
            ``model`` weights according to the ``criterion`` in each
            iteration.
        device: The device on which the network shall be trained.
        train_dataset: PyTorch dataset (``data.Dataset``) which produces
            training samples when iterated over.
            :py:class:`elektronn3.data.cnndata.PatchCreator` is currently
            recommended for constructing datasets.
        valid_dataset: PyTorch dataset (``data.Dataset``) which produces
            validation samples when iterated over.
            The length (``len(valid_dataset)``) of it determines how many
            samples are used for one validation metric calculation.
        unlabeled_dataset: Unlabeled dataset (only inputs) for
            semi-supervised training. If this is supplied, ``ss_criterion``
            needs to be set to the loss that should be computed on
            unlabeled inputs.
        valid_metrics: Validation metrics to be calculated on
            validation data after each training epoch. All metrics are logged
            to tensorboard.
        ss_criterion: Loss criterion for the self-supervised part of
            semi-supervised training. The ``ss_criterion`` loss is computed on
            batches from the ``unlabeled_dataset`` and added to the supervised
            loss in each training step.
        save_root: Root directory where training-related files are
            stored. Files are always written to the subdirectory
            ``save_root/exp_name/``. A leading ``~`` is expanded to the
            user's home directory.
        exp_name: Name of the training experiment. Determines the subdirectory
            to which files are written and should uniquely identify one
            training experiment.
            If ``exp_name`` is not set, it is auto-generated from the model
            name and a time stamp in the format ``'%y-%m-%d_%H-%M-%S'``.
        example_input: An example input tensor that can be fed to the
            ``model``. This is used for JIT tracing during model serialization.
        enable_save_trace: Deprecated. Setting this to ``True`` logs a
            warning and is equivalent to ``save_jit='trace'``.
            Please use the ``save_jit`` option instead.
        save_jit: Chooses if/how a JIT version (.pts file) of the ``model``
            should always be saved in addition to regular model snapshots.
            Choices:

            - ``None`` (default): Disable saving JIT models.
            - ``'script'`` (recommended if possible): The model is compiled
              directly with ``torch.jit.script()`` and saved as a .pts file
            - ``'trace'``: The model is JIT-traced with ``example_input`` and
              saved as a .pts file
        batch_size: Desired batch size of training samples.
        preview_batch: Set a fixed input batch for preview predictions.
            If it is ``None`` (default), preview batch functionality will be
            disabled.
            As a more powerful alternative for KNOSSOS datasets, consider
            using the ``knossos_preview_config`` option instead.
        knossos_preview_config: Configures preview batch creation and
            preview inferences based on a KNOSSOS dataset region.
            Here is an example of how it should look like:

            >>> knossos_preview_config = {
            ...     'dataset': 'path/to/knossos/dataset.conf',
            ...     'offset': [0, 0, 0],  # Offset (min) coordinates
            ...     'size': [256, 256, 64],  # Size (shape) of the region
            ...     'mag': 1,  # source mag
            ...     'target_mags': [1, 2, 3],  # List of target mags to which the inference results should be written
            ...     'remap_ids': None  # Config for class ID remapping (optional). See transforms.RemapTargetIDs
            ... }

            Periodic preview inference results are written to .k.zip
            annotation files that can be loaded with KNOSSOS and overlayed over
            the original data. .k.zip files are saved in the training
            directory, with file names reflecting their training step.
        preview_interval: Determines how often to perform preview inference.
            Preview inference is performed every ``preview_interval`` epochs
            during training. Regardless of this value, preview predictions
            will also be performed once after epoch 1.
            (To disable preview predictions altogether, just set
            ``preview_batch = None``).
        inference_kwargs: Additional options that are supplied to the
            :py:class:`elektronn3.inference.Predictor` instance that is used
            for periodic preview inference on the ``preview_batch``.
            The ``Trainer`` works on a copy of this dict, in which it fills
            in defaults for ``batch_size``, ``verbose`` and
            ``apply_softmax``.
        hparams: Dictionary of hyperparameters that is logged to tensorboard
            via ``add_hparams()`` if tensorboard logging is enabled. Tuple
            and list values are converted to strings because tensorboardX
            doesn't support them. The dict that is passed in is not modified.
        extra_save_steps: Permanent model snapshots are saved at the
            training steps specified here. E.g. with
            ``extra_save_steps = (0, 30, 3000)``, a snapshot is made at
            steps 0 (before training begins), step 30 and step 3000.
        num_workers: Number of background processes that are used to produce
            training samples without blocking the main training loop.
            See :py:class:`torch.utils.data.DataLoader`
            For normal training, you can mostly set ``num_workers=1``.
            Only use more workers if you notice a data loader bottleneck.
            Set ``num_workers=0`` if you want to debug the datasets
            implementation, to avoid multiprocessing-specific issues.
        schedulers: Dictionary of schedulers for training hyperparameters,
            e.g. learning rate schedulers that can be found in
            :py:mod:`torch.optim.lr_scheduler`.
            If it is not set, a no-op ``StepLR`` scheduler is used.
        overlay_alpha: Alpha (transparency) value for alpha-blending of
            overlay image plots.
        enable_videos: Enables video visualizations for 3D image data
            in tensorboard. Requires the moviepy package.
            Warning: Videos are stored as GIFs and can get very large, so
            only use this if you log seldom or have a lot of storage
            capacity.
        enable_tensorboard: If ``True``, tensorboard logging/plotting is
            enabled during training.
        tensorboard_root_path: Path to the root directory under which
            tensorboard log directories are created. Log ("event") files are
            written to a subdirectory that has the same name as the
            ``exp_name``.
            If ``tensorboard_root_path`` is not set, tensorboard logs are
            written to ``save_path`` (next to model checkpoints, plots etc.).
        ignore_errors: If ``True``, the training process tries to ignore
            all errors and continue with the next batch if it encounters
            an error on the current batch.
            It's not recommended to use this. It's only helpful for certain
            debugging scenarios.
        ipython_shell: If ``True`` keyboard interrupts (Ctrl-C) won't exit
            the process but only pause training and enter an IPython shell.
            Additionally, errors during training (except
            C-level segfaults etc.) won't crash the whole training process,
            but drop to an IPython shell so errors can be inspected with
            access to the current training state.
        out_channels: Optionally specifies the total number of different
            target classes for classification tasks. It is stored as
            ``self.out_channels`` (``None`` if not set) and, together with
            ``criterion.ignore_index`` (if the criterion has one),
            determines ``self.max_plot_id``.
            The ``out_channels`` attribute is used for plotting purposes and
            is not strictly required for training.
        sample_plotting_handler: Function that receives training and
            validation samples and is responsible for visualizing them by
            e.g. plotting them to tensorboard and/or writing them to files.
            It is called once after each training epoch and once after each
            validation run.
            If ``None``, a tensorboard-based default handler is used that
            works for most classification scenarios and for 3-channel
            regression.
        preview_plotting_handler: Function that is responsible for producing
            previews and visualizing/plotting/logging them.
            It is called once each ``preview_interval`` epochs.
            If ``None``, a tensorboard-based default handler is used that
            works for most classification scenarios.
        mixed_precision: If ``True``, enable Automated Mixed Precision
            training (using ``torch.cuda.amp`` autocasting and gradient
            scaling) to reduce memory usage and (if a GPU with Tensor Cores
            is used) make training much faster.
            This is currently experimental and might cause instabilities.
        tqdm_kwargs: Extra arguments to be passed to tqdm progress bars.
            For example, to disable tqdm outputs completely, pass
            ``tqdm_kwargs={'disable': True}``.

    Raises:
        ValueError: If ``preview_batch`` is set without the required
            ``inference_kwargs`` entries, if both ``preview_batch`` and
            ``knossos_preview_config`` are set, or if an
            ``unlabeled_dataset`` is given without an ``ss_criterion``.
        RuntimeError: If the ``save_path`` directory already exists.
    """

    tb: tensorboardX.SummaryWriter
    terminate: bool
    step: int
    epoch: int
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader
    exp_name: str
    save_path: str  # Full path to where training files are stored
    out_channels: Optional[int]  # Number of channels of the network outputs

    def __init__(
            self,
            model: torch.nn.Module,
            criterion: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            device: torch.device,
            save_root: str,
            train_dataset: torch.utils.data.Dataset,
            valid_dataset: Optional[torch.utils.data.Dataset] = None,
            unlabeled_dataset: Optional[torch.utils.data.Dataset] = None,
            valid_metrics: Optional[Dict] = None,
            ss_criterion: Optional[torch.nn.Module] = None,
            preview_batch: Optional[torch.Tensor] = None,
            knossos_preview_config: Optional[Dict[str, str]] = None,
            preview_interval: int = 5,
            inference_kwargs: Optional[Dict[str, Any]] = None,
            hparams: Optional[Dict[str, Any]] = None,
            extra_save_steps: Sequence[int] = (),
            exp_name: Optional[str] = None,
            example_input: Optional[torch.Tensor] = None,
            enable_save_trace: bool = False,
            save_jit: Optional[str] = None,
            batch_size: int = 1,
            num_workers: int = 0,
            schedulers: Optional[Dict[Any, Any]] = None,
            overlay_alpha: float = 0.4,
            enable_videos: bool = False,
            enable_tensorboard: bool = True,
            tensorboard_root_path: Optional[str] = None,
            ignore_errors: bool = False,
            ipython_shell: bool = False,
            out_channels: Optional[int] = None,
            sample_plotting_handler: Optional[Callable] = None,
            preview_plotting_handler: Optional[Callable] = None,
            mixed_precision: bool = False,
            tqdm_kwargs: Optional[Dict] = None
    ):
        # Work on a copy so the defaults that are filled in below don't leak
        # into the dict object owned by the caller.
        inference_kwargs = {} if inference_kwargs is None else dict(inference_kwargs)
        if preview_batch is not None and (
                'tile_shape' not in inference_kwargs or (
                    'overlap_shape' not in inference_kwargs and 'offset' not in inference_kwargs)):
            raise ValueError(
                'If preview_batch is set, you will also need to specify '
                'tile_shape and overlap_shape or offset in inference_kwargs!'
            )
        if knossos_preview_config is not None:
            if preview_batch is not None:
                raise ValueError('If you set a preview_knossos_source, you cannot also set a preview batch.')
            preview_batch = create_preview_batch_from_knossos(knossos_preview_config)
        if enable_save_trace:
            logger.warning('enable_save_trace is deprecated. Please use the save_jit option instead.')
            assert save_jit in [None, 'trace']
            save_jit = 'trace'

        # Ensure that all nn.Modules are on the right device
        model.to(device)
        if isinstance(criterion, torch.nn.Module):
            criterion.to(device)
        if isinstance(ss_criterion, torch.nn.Module):
            ss_criterion.to(device)

        self.ignore_errors = ignore_errors
        self.ipython_shell = ipython_shell
        self.device = device
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.valid_metrics = {} if valid_metrics is None else valid_metrics
        self.ss_criterion = ss_criterion
        self.preview_batch = preview_batch
        self.knossos_preview_config = knossos_preview_config
        self.preview_interval = preview_interval
        self.inference_kwargs = inference_kwargs
        self.extra_save_steps = extra_save_steps
        self.overlay_alpha = overlay_alpha
        self.save_root = os.path.expanduser(save_root)
        self.example_input = example_input
        self.save_jit = save_jit
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.sample_plotting_handler = sample_plotting_handler
        self.preview_plotting_handler = preview_plotting_handler
        self.mixed_precision = mixed_precision
        self.tqdm_kwargs = {} if tqdm_kwargs is None else tqdm_kwargs

        self._tracker = HistoryTracker()
        self._timer = Timer()
        self._first_plot = True
        self._shell_info = dedent("""
            Entering IPython training shell. To continue, hit Ctrl-D twice.
            To terminate, set self.terminate = True and then hit Ctrl-D twice.
        """).strip()

        self.inference_kwargs.setdefault('batch_size', 1)
        self.inference_kwargs.setdefault('verbose', True)
        self.inference_kwargs.setdefault('apply_softmax', True)

        if self.unlabeled_dataset is not None and self.ss_criterion is None:
            raise ValueError('If an unlabeled_dataset is supplied, you must also set ss_criterion.')

        if hparams is None:
            hparams = {}
        else:
            # Convert tuples and lists to str because tensorboardX doesn't
            # support them in add_hparams(). A new dict is built so that the
            # caller's dict is left untouched.
            hparams = {
                k: str(v) if isinstance(v, (tuple, list)) else v
                for k, v in hparams.items()
            }
        self.hparams = hparams

        self.scaler = amp.GradScaler(enabled=self.mixed_precision)

        if exp_name is None:  # Auto-generate a name based on model name and ISO timestamp
            timestamp = datetime.datetime.now().strftime('%y-%m-%d_%H-%M-%S')
            exp_name = model.__class__.__name__ + '__' + timestamp
        self.exp_name = exp_name
        # Use the expanded root (self.save_root) so that a leading "~" in
        # save_root doesn't end up as a literal directory name.
        self.save_path = os.path.join(self.save_root, exp_name)
        if os.path.isdir(self.save_path):
            raise RuntimeError(
                f'{self.save_path} already exists.\nPlease choose a '
                'different combination of save_root and exp_name.'
            )
        os.makedirs(self.save_path)
        _change_log_file_to(f'{self.save_path}/elektronn3.log')
        logger.info(f'Writing files to save_path {self.save_path}/\n')

        self.terminate = False
        self.step = 0
        self.epoch = 0
        if schedulers is None:
            schedulers = {'lr': StepLR(optimizer, 1000, 1)}  # No-op scheduler
        self.schedulers = schedulers
        self.__lr_closetozero_alreadytriggered = False  # Used in periodic scheduler handling
        self._lr_nhood = deque(maxlen=3)  # Keeps track of the last, current and next learning rate

        self.out_channels = out_channels
        self.max_plot_id = None
        try:
            self.max_plot_id = max(self.out_channels, self.criterion.ignore_index + 1)
        except AttributeError:  # no ignore_idx
            self.max_plot_id = self.out_channels
        except TypeError:  # no out_channels
            self.max_plot_id = None

        if enable_videos:
            try:
                import moviepy
            except Exception:
                # Video logging is optional: any failure while importing
                # moviepy only disables it. Unlike a bare "except", this does
                # not swallow KeyboardInterrupt or SystemExit.
                logger.warning('moviepy is not installed. Disabling video logs.')
                enable_videos = False
        self.enable_videos = enable_videos
        self.tb = None  # Tensorboard handler
        if enable_tensorboard:
            if self.sample_plotting_handler is None:
                self.sample_plotting_handler = handlers._tb_log_sample_images
            if self.preview_plotting_handler is None:
                self.preview_plotting_handler = handlers._tb_log_preview

            if tensorboard_root_path is None:
                tb_path = self.save_path
            else:
                tensorboard_root_path = os.path.expanduser(tensorboard_root_path)
                tb_path = os.path.join(tensorboard_root_path, self.exp_name)
                os.makedirs(tb_path, exist_ok=True)
            self.tb = tensorboardX.SummaryWriter(logdir=tb_path, flush_secs=20)

            if self.hparams:
                self.tb.add_hparams(hparam_dict=self.hparams, metric_dict={})

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            timeout=60 if self.num_workers > 0 else 0,
            worker_init_fn=_worker_init_fn
        )
        if valid_dataset is not None:
            self.valid_loader = DataLoader(
                self.valid_dataset,
                self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                pin_memory=True,
                worker_init_fn=_worker_init_fn
            )
        if self.unlabeled_dataset is not None:
            self.unlabeled_loader = DataLoader(
                self.unlabeled_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                pin_memory=True,
                timeout=60 if self.num_workers > 0 else 0,
                worker_init_fn=_worker_init_fn
            )
        self.best_val_loss = np.inf  # Best recorded validation loss
        self.best_tr_loss = np.inf
def run(self, max_steps: int = 1, max_runtime=3600 * 24 * 7) -> None:
    """Train the network for ``max_steps`` steps.

    After each training epoch, validation performance is measured and
    visualizations are computed and logged to tensorboard.

    Args:
        max_steps: Step limit that is forwarded to ``self._train()``.
        max_runtime: Runtime limit in seconds. It determines
            ``self.end_time`` and is forwarded to ``self._train()``.

    Raises:
        Exception: Any exception raised during training is re-raised
            unless ``self.ignore_errors`` or ``self.ipython_shell`` is set.
    """
    self.start_time = datetime.datetime.now()
    self.end_time = self.start_time + datetime.timedelta(seconds=max_runtime)
    try:
        self._save_model(suffix='_initial', verbose=False)
        self._lr_nhood.clear()
        self._lr_nhood.append(self.optimizer.param_groups[0]['lr'])  # LR of the first training step
        while not self.terminate:
            try:
                stats, misc, tr_sample_images = self._train(max_steps, max_runtime)
                self.epoch += 1

                if self.valid_dataset is None:
                    stats['val_loss'] = nan
                    val_sample_images = None
                else:
                    valid_stats, val_sample_images = self._validate()
                    stats.update(valid_stats)

                # Log to stdout and text log file
                self._log_basic(stats, misc)
                # Render visualizations and log to tensorboard
                self._log_to_tensorboard(stats, misc, tr_sample_images, val_sample_images)
                # Legacy non-tensorboard logging to files
                self._log_to_history_tracker(stats, misc)

                # Save trained model state
                self._save_model(val_loss=stats['val_loss'], verbose=False)  # Not verbose because it can get spammy.
                # TODO: Support other metrics for determining what's the "best" model?
                # Note: A nan val_loss (no validation dataset) never compares
                # as smaller, so no "_best" snapshot is written in that case.
                if stats['val_loss'] < self.best_val_loss:
                    self.best_val_loss = stats['val_loss']
                    self._save_model(suffix='_best', val_loss=stats['val_loss'])
            except KeyboardInterrupt:
                if self.ipython_shell:
                    IPython.embed(header=self._shell_info)
                else:
                    break
                if self.terminate:
                    break
            except Exception:
                logger.exception('Unhandled exception during training:')
                if self.ignore_errors:
                    # Just print the traceback and try to carry on with training.
                    # This can go wrong in unexpected ways, so don't leave the training unattended.
                    pass
                elif self.ipython_shell:
                    print("\nEntering Command line such that Exception can be "
                          "further inspected by user.\n\n")
                    IPython.embed(header=self._shell_info)
                    if self.terminate:
                        break
                else:
                    # Bare raise keeps the original traceback untouched.
                    raise
        self._save_model(suffix='_final')
    finally:
        # Close the writer even if training is aborted by an exception, so
        # that pending tensorboard events are not lost.
        if self.tb is not None:
            self.tb.close()  # Ensure that everything is flushed
def _train_step(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: """Core training step on self.device""" inp = batch['inp'] target = batch.get('target') target_class = batch.get('class') # Everything with a "d" prefix refers to tensors on self.device (i.e. probably on GPU) dinp = inp.to(self.device, non_blocking=True) dtarget = target.to(self.device, non_blocking=True) if target is not None else None dtarget_class = target_class.to(self.device, non_blocking=True) if target_class is not None else None # forward pass with amp.autocast(enabled=self.mixed_precision): dout = self.model(dinp) if dtarget_class is not None: dloss = self.criterion(dout, dtarget, dtarget_class) else: dloss = self.criterion(dout, dtarget) unlabeled = batch.get('unlabeled') if unlabeled is not None: # Add a simple consistency loss u_inp = unlabeled['inp'] du_inp = u_inp.to(self.device, non_blocking=True) with amp.autocast(enabled=self.mixed_precision): du_loss = self.ss_criterion(du_inp) dloss += du_loss self.tb.add_scalar('stats/tr_uloss', float(du_loss), global_step=self.step) if torch.isnan(dloss): logger.error('NaN loss detected! 
Aborting training.') raise NaNException # update step self.scaler.scale(dloss).backward() self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad(set_to_none=True) return dloss, dout def _train(self, max_steps, max_runtime): """Train for one epoch or until max_steps or max_runtime is reached""" self.model.train() # Scalar training stats that should be logged and written to tensorboard later stats: Dict[str, Union[float, List[float]]] = {stat: [] for stat in ['tr_loss']} # Other scalars to be logged misc: Dict[str, Union[float, List[float]]] = {misc: [] for misc in ['mean_target']} # Hold image tensors for real-time training sample visualization in tensorboard images: Dict[str, np.ndarray] = {} running_vx_size = 0 # Counts input sizes (number of pixels/voxels) of training batches timer = Timer() batch_iter = tqdm( self.train_loader, 'Training', total=len(self.train_loader), dynamic_ncols=True, **self.tqdm_kwargs ) unlabeled_iter = None if self.unlabeled_dataset is None else iter(self.unlabeled_loader) for i, batch in enumerate(batch_iter): if self.step in self.extra_save_steps: self._save_model(f'_step{self.step}', verbose=True) if unlabeled_iter is not None: batch['unlabeled'] = next(unlabeled_iter) dloss, dout = self._train_step(batch) with torch.no_grad(): loss = float(dloss) target = batch.get('target') mean_target = float(target.to(torch.float32).mean()) if target is not None else 0. 
misc['mean_target'].append(mean_target) stats['tr_loss'].append(loss) batch_iter.set_description(f'Training (loss {loss:.4f})') self._tracker.update_timeline([self._timer.t_passed, loss, mean_target]) # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr() misc['learning_rate'] = self.optimizer.param_groups[0]['lr'] # LR for the this iteration self._scheduler_step(loss) running_vx_size += batch['inp'].numel() self._incr_step(max_runtime, max_steps) if i == len(self.train_loader) - 1 or self.terminate: # Last step in this epoch or in the whole training # Preserve last training batch and network output for later visualization images['inp'] = batch['inp'].numpy() images['fname'] = batch.get('fname') if 'target' in batch: images['target'] = batch['target'].numpy() if 'unlabeled' in batch: images['unlabeled'] = batch['unlabeled'] images['out'] = dout.detach().cpu().numpy() self._put_current_attention_maps_into(images) if self.terminate: break stats['tr_loss_std'] = np.std(stats['tr_loss']) misc['tr_speed'] = len(self.train_loader) / timer.t_passed misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6 # MVx return stats, misc, images def _put_current_attention_maps_into(self, images): if getattr(self.model, 'attention', None): for i in range(len(self.model.up_convs)): att = self.model.up_convs[i].att[0][0].detach().cpu().numpy() if att.ndim == 3: att = att[att.shape[0] // 2] images[f'att{i}'] = att def _incr_step(self, max_runtime, max_steps): """Increment the current training step counter""" self.step += 1 if self.step >= max_steps: logger.info(f'max_steps ({max_steps}) exceeded. Terminating...') self.terminate = True if datetime.datetime.now() >= self.end_time: logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...') self.terminate = True def _scheduler_step(self, loss): """Update schedules""" for sched in self.schedulers.values(): # support ReduceLROnPlateau; doc. 
uses validation loss instead # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau if 'metrics' in inspect.signature(sched.step).parameters: sched.step(metrics=loss) else: sched.step() # Append LR of the next iteration (after sched.step()) for local LR minima detection self._lr_nhood.append(self.optimizer.param_groups[0]['lr']) self._handle_lr() def _handle_lr(self) -> None: r"""Handle quasi-periodic learning rate schedulers that lower the learning rate to local minima but then ramp it up again (Cosine Annealing, SGDR, Cyclical LRs etc.). Model saving is triggered when a local minimum of learning rates is detected. For the motivation of this behavior, see https://arxiv.org/abs/1704.00109. The saved models can be used to build an ensemble. Local minima are found by checking for the simple criterion :math:`\lr_{t-1}` > \lr{t} < lr{t+1}`. If an SWA (Stochastic Weight Averaging) optimizer is detected, the SWA algorithm is performed (see https://arxiv.org/abs/1803.05407) and the resulting model is also saved, marked by the "_swa" file name suffix. .. note:: The saved SWA model performs batch norm statistics correction only on a limited number of batches from the ``self.train_loader`` (currently hardcoded to 10), so if the model uses batch normalization with running statistics and you suspect that this amount of batches won't be representative enough for your input data distribution, you might want to ensure a good estimate yourself by running :py:meth:`elektronn3.trainer.SWA.bn_update()` on the model with a larger number of input batches after loading the model for inference. """ if len(self._lr_nhood) < 3: return # Can't get lrs, but at this early stage it's also not relevant last_lr = self._lr_nhood[-3] curr_lr = self._lr_nhood[-2] next_lr = self._lr_nhood[-1] if last_lr > curr_lr < next_lr: logger.info( f'Local learning rate minimum {curr_lr:.2e} detected at step ' f'{self.step}. 
Saving model...') self._save_model(suffix=f'_minlr_step{self.step}') # Handle Stochastic Weight Averaging optimizer if SWA is used if isinstance(self.optimizer, SWA): # TODO: Make bn_update configurable (esp. number of batches) self.optimizer.update_swa() # Put current model params into SWA buffer self.optimizer.swap_swa_sgd() # Perform SWA and write results into model params has_bn = any(isinstance(m, torch.nn.modules.batchnorm._BatchNorm) for m in self.model.modules()) if has_bn: # Perform batch norm correction try: max_bn_corr_batches = 10 # Batches to use to correct SWA batchnorm stats # We're assuming here that len(self.train_loader), which is an upper bound for # len(swa_loader), is sufficient for a good stat estimation swa_loader = islice(self.train_loader, max_bn_corr_batches) # This may be expensive (comparable to validation computations) SWA.bn_update(swa_loader, self.model, device=self.device) self._save_model(suffix='_swa', verbose=False) except: logger.exception( 'SWA helper bn_update has failed. SWA model will be saved with incorrect ' 'batchnorm statistics. Please make sure to manually correct the BN stats ' 'before deploying the model.' ) self._save_model(suffix='_swa_todo_batchnorm_corr', verbose=False) else: # No batch norm -> save model directly self._save_model(suffix='_swa', verbose=False) self.optimizer.swap_swa_sgd() # Swap back model to the original state before SWA @torch.no_grad() def _validate(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]: self.model.eval() # Set dropout and batchnorm to eval mode val_loss = [] outs = [] targets = [] stats = {name: [] for name in self.valid_metrics.keys()} batch_iter = tqdm( enumerate(self.valid_loader), 'Validating', total=len(self.valid_loader), dynamic_ncols=True, **self.tqdm_kwargs ) for i, batch in batch_iter: # Everything with a "d" prefix refers to tensors on self.device (i.e. 
probably on GPU) inp = batch['inp'] target = batch.get('target') target_class = batch.get('class') dinp = inp.to(self.device, non_blocking=True) dtarget = target.to(self.device, non_blocking=True) if target is not None else None dtarget_class = target_class.to(self.device, non_blocking=True) if target_class is not None else None with amp.autocast(enabled=self.mixed_precision): dout = self.model(dinp) if dtarget is None: # Use self-supervised unary loss function val_loss.append(self.ss_criterion(dout).item()) elif dtarget_class is not None: val_loss.append(self.criterion(dout, dtarget, dtarget_class).item()) else: val_loss.append(self.criterion(dout, dtarget).item()) out = dout.detach().cpu() outs.append(out) targets.append(target) images = { 'inp': inp.numpy(), 'out': out.numpy(), 'target': None if target is None else target.numpy(), 'fname': batch.get('fname'), } self._put_current_attention_maps_into(images) stats['val_loss'] = np.mean(val_loss) stats['val_loss_std'] = np.std(val_loss) for name, evaluator in self.valid_metrics.items(): mvals = [evaluator(target, out) for target, out in zip(targets, outs)] if np.all(np.isnan(mvals)): stats[name] = np.nan else: stats[name] = np.nanmean(mvals) # # This code is currently commented out because it's quite slow. TODO: Speed up by computing softmax on GPU above # # Plot per-class PR curves if a supported classification scenario is detected. 
# if out.ndim == target.ndim + 1 and self.inference_kwargs.get('apply_softmax'): # softmax_outs = torch.stack(outs).softmax(2) # Apply softmax in dim=2 because of additional stack dim # for c in range(out.shape[1]): # self.tb.add_pr_curve( # f'pr_c{c}', # labels=torch.stack(targets), # predictions=torch.stack([so[:, c] for so in softmax_outs]), # global_step=self.step # ) self.model.train() # Reset model to training mode return stats, images # TODO: Instead of using specific keys like val_loss, enable passing info as an # extra dict whose contents will be added to the state_dict def _save_model( self, suffix: str = '', unwrap_parallel: bool = True, verbose: bool = True, val_loss=np.nan ) -> None: """Save/serialize trained model state to files. Writes the following files in the ``self.save_path``: - ``state_dicts.pth`` contains the a dict that holds the ``state_dict`` of the trained model, the ``state_dict`` of the optimizer and some meta information (global step, epoch, best validation loss) The included parameters can be read and used to overwrite another model's ``state_dict``. The model code (architecture) itself is not included in this file. - ``model.pt`` contains a pickled version of the complete model, including the trained weights. You can simply ``model = torch.load('model.pt')`` to obtain the full model and its training state. This will not work if the source code relevant to de- serializing the model object has changed! If this is is the case, you will need to use the ``state_dict.pth`` to extract parameters and manually load them into a model. - ``model.pts`` contains the model in the ``torch.jit`` ScriptModule serialization format. If ``model`` is not already a ``ScriptModule`` and ``self.save_jit`` is not ``None``, a ScriptModule form of the ``model`` will be created on demand. Args: suffix: If defined, this string will be added before the file extensions of the respective files mentioned above. 
unwrap_parallel: If ``True`` (default) and the model uses a parallel module wrapper like ``torch.nn.DataParallel``, this is automatically detected and the wrapped model is saved directly to make later deserialization easier. This can be disabled by setting ``unwrap_parallel=False``. verbose: If ``True`` (default), log infos about saved models at log-level "INFO" (which appears in stdout). Else, only silently log with log-level "DEBUG". val_loss: Stores the validation loss (default value if not supplied: NaN) """ log = logger.info if verbose else logger.debug model = self.model model_trainmode = model.training # We do this awkard check because there are too many different # parallel wrappers in PyTorch and some of them have changed names # in different releases (DataParallel, DistributedDataParallel{,CPU}). is_wrapped = ( hasattr(model, 'module') and 'parallel' in str(type(model)).lower() and isinstance(model.module, torch.nn.Module) ) if is_wrapped and unwrap_parallel: # If a parallel wrapper was used, the only thing we should save # is the model.module, which contains the actual model and params. # If we saved the wrapped module directly, deserialization would # get unnecessarily difficult. 
model = model.module state_dict_path = os.path.join(self.save_path, f'state_dict{suffix}.pth') model_path = os.path.join(self.save_path, f'model{suffix}.pt') try: lr_sched_state = self.schedulers['lr'].state_dict() except: # No valid scheduler in use lr_sched_state = None info = { 'global_step': self.step, 'epoch': self.epoch, 'best_val_loss': self.best_val_loss, 'val_loss': val_loss, 'inference_kwargs': self.inference_kwargs, 'elektronn3.__version__': elektronn3.__version__, 'env_info': collect_env.get_pretty_env_info() } # Make sure everything is a string (if inference_kwargs contains a # transform object, it may not be picklable) info = {k: str(v) for k, v in info.items()} torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'lr_sched_state_dict': lr_sched_state, 'scaler_state_dict': self.scaler.state_dict(), 'info': info }, state_dict_path) log(f'Saved state_dict as {state_dict_path}') pts_model_path = f'{model_path}s' try: # Try saving directly as an uncompiled nn.Module torch.save(model, model_path) log(f'Saved model as {model_path}') if self.save_jit == 'script': # Compile directly for serialization jitmodel = torch.jit.script(model) elif self.save_jit == 'trace': # Trace and serialize the model in eval mode if self.example_input is None: raise ValueError('If save_jit="trace", example_input needs to be specified.') with warnings.catch_warnings(): # It's enough to be warned once during initial tracing warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) jitmodel = torch.jit.trace(model.eval(), self.example_input.to(self.device)) if self.save_jit is not None: # Save jit model, either from script or trace jitmodel.save(pts_model_path) log(f'Saved jitted model ({self.save_jit}) as {pts_model_path}') except (TypeError, PickleError) as exc: # If model is already a ScriptModule, it can't be saved with torch.save() # Use ScriptModule.save() instead in this case. 
# Using the file extension '.pts' to show it's a ScriptModule. if isinstance(model, torch.jit.ScriptModule): model_path += 's' model.save(pts_model_path) log(f'Saved jitted model as {pts_model_path}') else: raise exc finally: # Reset training state to the one it had before this function call, # because it could have changed with the model.eval() call above. model.training = model_trainmode if os.path.isfile(pts_model_path): with zipfile.ZipFile(pts_model_path, 'a', compression=zipfile.ZIP_DEFLATED) as zfile: infostr = pprint.pformat(info, indent=2, width=120) zfile.writestr('info.txt', infostr) def _log_basic(self, stats, misc): """Log to stdout and text log file""" tr_loss = np.mean(stats['tr_loss']) val_loss = np.mean(stats['val_loss']) lr = misc['learning_rate'] tr_speed = misc['tr_speed'] tr_speed_vx = misc['tr_speed_vx'] t = pretty_string_time(self._timer.t_passed) text = f'step={self.step:06d}, tr_loss={tr_loss:.3f}, val_loss={val_loss:.3f}, ' text += f'lr={lr:.2e}, {tr_speed:.2f} it/s, {tr_speed_vx:.2f} MVx/s, {t}' logger.info(text) def _log_to_tensorboard( self, stats: Dict, misc: Dict, tr_images: Dict, val_images: Optional[Dict] = None, file_stats: Optional[Dict] = None, ) -> None: """Create visualizations, make preview predictions, log and plot to tensorboard""" if self.tb: try: self._tb_log_scalars(stats, 'stats') self._tb_log_scalars(misc, 'misc') if self.preview_batch is not None: if self.epoch % self.preview_interval == 0 or self.epoch == 1: # TODO: Also save preview inference results in a (3D) HDF5 file self.preview_plotting_handler(self) self.sample_plotting_handler(self, tr_images, group='tr_samples') if val_images is not None: self.sample_plotting_handler(self, val_images, group='val_samples') if file_stats is not None: self._tb_log_scalars(file_stats, 'file_stats') self._tb_log_histograms() except Exception: logger.exception('Error occured while logging to tensorboard:') def _log_to_history_tracker(self, stats: Dict, misc: Dict) -> None: 
"""Update history tracker and plot stats (kind of made obsolete by tensorboard)""" # TODO: Decide what to do with this, now that most things are already in tensorboard. if self._tracker.history.length > 0: tr_loss_gain = self._tracker.history[-1][2] - np.mean(stats['tr_loss']) else: tr_loss_gain = 0 if not stats.get('tr_accuracy'): tr_accuracy = nan else: tr_accuracy = np.nanmean(stats['tr_accuracy']) val_accuracy = stats.get('val_accuracy', nan) self._tracker.update_history([ self.step, self._timer.t_passed, np.mean(stats['tr_loss']), np.mean(stats['val_loss']), tr_loss_gain, tr_accuracy, val_accuracy, misc['learning_rate'], 0, 0 ]) # Plot tracker stats to pngs in save_path self._tracker.plot(self.save_path) def _tb_log_scalars( self, scalars: Dict[str, float], tag: str = 'default' ) -> None: for key, value in scalars.items(): if isinstance(value, (list, tuple, np.ndarray)): for i in range(len(value)): if not np.isnan(value[i]): self.tb.add_scalar(f'{tag}/{key}', value[i], self.step - len(value) + i) elif not np.isnan(value): self.tb.add_scalar(f'{tag}/{key}', value, self.step) def _tb_log_histograms(self) -> None: """Log histograms of model parameters and their current gradients. Make sure to run this between ``backward()`` and ``zero_grad()``, because otherwise gradient histograms will only consist of zeros. """ for name, param in self.model.named_parameters(): self.tb.add_histogram(f'param/{name}', param, self.step) grad = param.grad if param.grad is not None else torch.tensor(0) self.tb.add_histogram(f'grad/{name}', grad, self.step) def _preview_inference( self, inp: np.ndarray, inference_kwargs: Dict[str, Any], ) -> torch.Tensor: if self.out_channels is None: raise RuntimeError('Can\'t do preview prediction if Trainer.out_channels is not set.') out_shape = (self.out_channels, *inp.shape[2:]) predictor = Predictor( model=self.model, device=self.device, out_shape=out_shape, **inference_kwargs, ) out = predictor.predict(inp) return out
class Backup:
    """Archive the training script, the package sources and environment info.

    Should be used for any future archiving needs.

    Args:
        script_path: The path to the training script.
            Eg. train_unet_neurodata.py
        save_path: The path where the information is archived.
        extra_content: Dictionary of {filename: content} entries, where
            content is a string that should be written to a file with the
            specified name.
    """

    def __init__(self, script_path, save_path, extra_content=None):
        self.extra_content = extra_content
        self.save_path = save_path
        self.script_path = script_path

    def archive_backup(self):
        """Archive the source folder, the training script and environment info.

        The training script is saved with the prefix "0-" to distinguish it
        from regular scripts. Environment information equivalent to the
        output of ``python -m torch.utils.collect_env`` is saved in a file
        named "env_info.txt". Entries of ``extra_content`` (if given) are
        written to files of their own inside ``save_path``.
        """
        # Copy of the training script, made executable
        script_copy = self.save_path + '/0-' + os.path.basename(self.script_path)
        shutil.copyfile(self.script_path, script_copy)
        os.chmod(script_copy, 0o755)

        # Compressed archive of the package source folder
        shutil.make_archive(
            os.path.join(self.save_path, 'src_backup'),
            'gztar',
            os.path.dirname(arch_src),
        )

        # Environment info dump
        with open(self.save_path + '/env_info.txt', 'w') as env_file:
            env_file.write(collect_env.get_pretty_env_info())

        if self.extra_content is None:
            return
        for fname, content in self.extra_content.items():
            with open(f'{self.save_path}/{fname}', 'w') as extra_file:
                extra_file.write(content)
def findcudatensors() -> Tuple[float, List[torch.Tensor]]:
    """Find currently living tensors that are allocated on cuda device memory.

    This can be used for debugging memory leaks:
    If ``findcudatensors()[0]`` grows unexpectedly between GPU computations,
    you can look at the returned ``tensors`` list to find out what tensors
    are currently allocated, for example
    ``print([x.shape for x in findcudatensors()[1]])``.

    Model parameters (``torch.nn.Parameter``) are excluded.

    Returns a tuple of

    - total memory usage of found tensors in MiB
    - a list of all of those tensors, ordered by number of elements."""
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) \
                    and obj.device.type == 'cuda' \
                    and not isinstance(obj, torch.nn.Parameter):  # Exclude model params
                tensors.append(obj)
        except Exception:
            # Arbitrary tracked objects can raise on inspection; skip them.
            pass
    tensors.sort(key=lambda x: x.numel())
    # element_size() is the per-element size in bytes for the tensor's actual dtype
    total_mib = sum(x.numel() * x.element_size() for x in tensors) / 1024**2
    return total_mib, tensors