Source code for

from typing import Dict, Any, Tuple, Union, List

import numpy as np
import torch
from tqdm import tqdm

from import Trainer, handlers
from import Timer
from import logger, NaNException

class TripletTrainer(Trainer):
    """Trainer variant that optimizes a triplet loss.

    Each training batch is expected to be a dict holding the tensors
    ``'anchor'``, ``'pos'`` and ``'neg'``. All three are passed through
    ``self.model`` separately and the three outputs are handed to
    ``self.criterion``.

    Validation is not implemented for this trainer (``_validate`` raises
    ``NotImplementedError``).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Trainer`` and install the plotting handlers."""
        super().__init__(*args, **kwargs)

        def noop(*args, **kwargs):
            pass

        # Preview plotting is disabled by replacing the handler with a no-op.
        self.preview_plotting_handler = noop  # TODO
        self.sample_plotting_handler = handlers._tb_log_sample_images_all_img

    def _train_step_triplet(
            self,
            batch: Dict[str, Any]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Core training step for triplet loss on self.device.

        Moves the anchor/positive/negative inputs to ``self.device``, runs
        one forward pass per input, computes the loss and performs a single
        optimizer update.

        Args:
            batch: Dict with the tensor entries ``'anchor'``, ``'pos'`` and
                ``'neg'``.

        Returns:
            A tuple ``(dloss, outputs)``: the loss tensor (still on
            ``self.device``) and a dict with the model outputs under the
            keys ``'anchor_out'``, ``'pos_out'`` and ``'neg_out'``.

        Raises:
            NaNException: If the computed loss is NaN.
        """
        # Everything with a "d" prefix refers to tensors on self.device
        # (i.e. probably on GPU)
        danchor = batch['anchor'].to(self.device, non_blocking=True)
        dpos = batch['pos'].to(self.device, non_blocking=True)
        dneg = batch['neg'].to(self.device, non_blocking=True)

        # forward pass
        danc_out = self.model(danchor)
        dpos_out = self.model(dpos)
        dneg_out = self.model(dneg)
        dloss = self.criterion(danc_out, dpos_out, dneg_out)
        if torch.isnan(dloss):
            logger.error('NaN loss detected! Aborting training.')
            raise NaNException

        # update step
        self.optimizer.zero_grad()
        dloss.backward()
        self.optimizer.step()

        return dloss, {'anchor_out': danc_out, 'pos_out': dpos_out, 'neg_out': dneg_out}

    def _train(self, max_steps, max_runtime):
        """Train for one epoch or until max_steps or max_runtime is reached.

        Args:
            max_steps: Passed on to ``self._incr_step``.
            max_runtime: Passed on to ``self._incr_step``.

        Returns:
            A tuple ``(stats, misc, images)``:

            - ``stats``: ``'tr_loss'`` (list of per-step losses) and
              ``'tr_loss_std'`` (their standard deviation).
            - ``misc``: ``'mean_target'`` (list of dummy zeros),
              ``'learning_rate'``, ``'tr_speed'`` and ``'tr_speed_vx'``.
            - ``images``: entries of the last processed batch (tensors
              converted to numpy arrays), further filled by
              ``self._put_current_attention_maps_into``.
        """
        self.model.train()

        # Scalar training stats that should be logged and written to tensorboard later
        stats: Dict[str, Union[float, List[float]]] = {'tr_loss': []}
        # Other scalars to be logged
        misc: Dict[str, Union[float, List[float]]] = {'mean_target': []}
        # Hold image tensors for real-time training sample visualization in tensorboard
        images: Dict[str, np.ndarray] = {}

        running_vx_size = 0  # Counts input sizes (number of pixels/voxels) of training batches
        timer = Timer()
        batch_iter = tqdm(
            self.train_loader,
            'Training',
            total=len(self.train_loader),
            dynamic_ncols=True,
            **self.tqdm_kwargs
        )
        for i, batch in enumerate(batch_iter):
            if self.step in self.extra_save_steps:
                self._save_model(f'_step{self.step}', verbose=True)

            dloss, dout_imgs = self._train_step_triplet(batch)

            with torch.no_grad():
                loss = float(dloss)
                mean_target = 0.  # Dummy value
                misc['mean_target'].append(mean_target)
                stats['tr_loss'].append(loss)
                batch_iter.set_description(f'Training (loss {loss:.4f})')
                self._tracker.update_timeline([self._timer.t_passed, loss, mean_target])

            # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
            misc['learning_rate'] = self.optimizer.param_groups[0]['lr']  # LR for this iteration
            self._scheduler_step(loss)

            running_vx_size += batch['anchor'].numel()

            self._incr_step(max_runtime, max_steps)

            if i == len(self.train_loader) - 1 or self.terminate:
                # Last step in this epoch or in the whole training
                # Preserve last training batch and network output for later visualization
                for key, img in batch.items():
                    if isinstance(img, torch.Tensor):
                        img = img.detach().cpu().numpy()
                    images[key] = img
                self._put_current_attention_maps_into(images)

                # TODO: The plotting handler abstraction is inadequate here. Figure out how
                #       we can handle plotting cleanly in one place.
                # Outputs are visualized here, while inputs are visualized in the plotting
                # handler which is called in _run()...
                for name, img in dout_imgs.items():
                    img = img.detach()[0].cpu().numpy()  # select first item of batch
                    if img.ndim == 4:  # 3D data
                        # Take the center slice of the depth dim -> 2D per channel.
                        # Axis 0 is iterated as the channel axis below, so the axis
                        # being sliced is axis 1 and the center index must be derived
                        # from shape[1], not from the channel count shape[0].
                        img = img[:, img.shape[1] // 2]
                    for c in range(img.shape[0]):
                        self.tb.add_figure(
                            f'tr_samples/{name}_c{c}',
                            handlers.plot_image(img[c], cmap='gray'),
                            global_step=self.step
                        )

            if self.terminate:
                break

        stats['tr_loss_std'] = np.std(stats['tr_loss'])
        # NOTE(review): uses the full loader length even if the loop above was left
        # early via self.terminate -- confirm whether that is intended.
        misc['tr_speed'] = len(self.train_loader) / timer.t_passed
        misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx

        return stats, misc, images

    def _validate(self) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]:
        """Validation is not supported by this trainer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError