# Source code for elektronn3.training.noise2void

"""
Trainer code for 2D and 3D Noise2Void (https://arxiv.org/abs/1811.10980)
Adapted from https://github.com/juglab/pn2v/blob/master/pn2v/training.py,
ported from NumPy to PyTorch and generalized to support 3D.
"""

from typing import Callable

import torch
from torch import nn
from torch.cuda import amp
import numpy as np
import itertools

from scipy.ndimage.filters import gaussian_filter
from tqdm import tqdm

from elektronn3.training.trainer import Trainer, NaNException
from elektronn3.modules.loss import MaskedMSELoss

import logging
logger = logging.getLogger('elektronn3log')


@torch.no_grad()
def get_stratified_coords(ratio, shape):
    """Sample random pixel coordinates from ``shape`` using stratified sampling.

    The space is tiled into equal boxes and one random coordinate is drawn
    per box, so approximately ``ratio * prod(shape)`` coordinates are
    produced, spread evenly over the image. Supports n-dimensional shapes.

    Args:
        ratio: Approximate fraction of pixels to sample (0 < ratio <= 1).
        shape: Spatial shape from which coordinates are sampled.

    Returns:
        Int tensor of shape ``(num_sampled, ndim)``, one coordinate per row.

    Raises:
        ValueError: If ``ratio`` is so small that no sampled coordinate
            fits inside ``shape``.
    """
    ratio = torch.as_tensor(ratio)
    ndim = len(shape)
    shape = torch.as_tensor(shape, dtype=torch.int32)
    # Box edge length so that one sample per ndim-dimensional box yields a
    # fraction of ~ratio sampled pixels. (A previous version hardcoded
    # sqrt(1 / ratio), which only realizes the requested ratio for 2D and
    # under-samples 3D volumes.) Clamp to >= 1 so torch.randint stays valid.
    box_size = max(int(torch.round((1. / ratio) ** (1. / ndim))), 1)
    coords = []
    box_counts = torch.ceil(shape.float() / box_size).int()
    for steps in itertools.product(*[range(bc) for bc in box_counts]):
        steps = torch.as_tensor(steps, dtype=torch.int32)
        # Box origin plus a random offset within the box
        co = torch.randint(0, box_size, (ndim,)) + box_size * steps
        # Boxes at the far edges may extend beyond the image; drop
        # out-of-bounds samples.
        if torch.all(co < shape):
            coords.append(co)
    if not coords:
        raise ValueError(f'ratio {ratio:.1e} is too close to zero. Choose a higher value.')
    coords = torch.stack(coords)
    return coords


# TODO: Is the hardcoded small ROI size sufficient?
@torch.no_grad()
def prepare_sample(img, ratio=1e-3, channels=None):
    """Prepare binary mask and target image for Noise2Void from a given image.

    For each sample/channel, ~``ratio`` of the pixels are selected by
    stratified sampling. Each selected ("hot") pixel is replaced in the
    returned input by a random other pixel from its surrounding (up to)
    5x5(x5) ROI and marked in the mask, so a loss can be restricted to
    the manipulated locations.

    Args:
        img: Image batch tensor of shape (N, C, [D,] H, W), i.e. 2 or 3
            spatial dims after the (N, C) dims. Assumed to be on CPU or a
            device whose values can be indexed elementwise.
        ratio: Approximate fraction of pixels to manipulate per channel.
        channels: Channel indices to process. Defaults to all channels.

    Returns:
        Tuple ``(inp, target, mask)`` where ``inp`` is the manipulated
        copy, ``target`` is the original ``img`` (not copied) and ``mask``
        is 1.0 at manipulated pixels, 0.0 elsewhere.
    """
    ndim = img.ndim - 2  # Subtract (N, C) dims
    if channels is None:
        channels = range(img.shape[1])
    inp = img.clone()
    target = img  # Target is the unmodified original; no copy needed
    mask = torch.zeros_like(img)
    for n, c in itertools.product(range(img.shape[0]), channels):
        hotcoords = get_stratified_coords(ratio, img[n, c].shape)
        spatial_shape = np.array(img[n, c].shape)
        for hc in hotcoords:
            roimin = np.clip(hc - 2, 0, None)
            # roimax is an *exclusive* slice bound, so clip it to the full
            # shape (clipping to shape - 1 would shrink border ROIs and
            # could even exclude the hot pixel itself at the far border).
            roimax = np.clip(hc + 3, None, spatial_shape)
            roi = img[n, c, roimin[0]:roimax[0], roimin[1]:roimax[1]]
            if ndim == 3:
                roi = roi[..., roimin[2]:roimax[2]]  # slice 3rd dim if input is 3D
            # Offset of the hot pixel within the ROI (2 in each dim except
            # near borders, where the ROI is truncated). Redraw until the
            # replacement is not the hot pixel itself.
            center = hc.numpy() - roimin
            rc = center.copy()
            while np.all(rc == center):
                rc = np.random.randint(0, roi.shape, (ndim,))
            repl = roi[tuple(rc)]  # Select point at rc in current ROI for replacement
            inp[(n, c, *hc)] = repl
            mask[(n, c, *hc)] = 1.0

    return inp, target, mask


class Noise2VoidTrainer(Trainer):
    """Trainer subclass with custom training and validation code for
    Noise2Void training.

    Noise2Void is applied by default, but it can also be replaced or
    accompanied by additive gaussian noise and gaussian blurring (see
    args below).

    Args:
        model: PyTorch model (``nn.Module``) that shall be trained.
        criterion: Training criterion. If ``n2v_ratio > 0``, it should
            expect 3 arguments, the third being the Noise2Void mask.
            Per default, a masked MSE loss is used.
        *args: *Other positional args. See signature of
            :py:class:`elektronn3.training.Trainer`*
        n2v_ratio: Ratio of pixels to be manipulated and masked in each
            image according to the Noise2Void algorithm. If it is set to a
            value <= 0, Noise2Void is disabled.
        agn_max_std: Maximum std (sigma parameter) for additive gaussian
            noise that is optionally applied to the input image. Standard
            deviations are sampled from a uniform distribution that ranges
            between 0 and ``agn_max_std``. If it is set to a value <= 0,
            additive gaussian noise is disabled.
        gblur_sigma: Sigma parameter for gaussian blurring that is
            optionally applied to the input image. If it is set to a value
            <= 0, gaussian blurring is disabled.
        **kwargs: Other keyword args. See signature of
            :py:class:`elektronn3.training.Trainer`
    """

    def __init__(
            self,
            model: torch.nn.Module,
            criterion: torch.nn.Module = MaskedMSELoss(),
            *args,
            n2v_ratio: float = 1e-3,
            agn_max_std: float = 0,
            gblur_sigma: float = 0,
            **kwargs
    ):
        super().__init__(model, criterion, *args, **kwargs)
        self.n2v_ratio = n2v_ratio
        self.agn_max_std = agn_max_std
        self.gblur_sigma = gblur_sigma

    def _degrade_batch(self, batch):
        """Build ``(dinp, dtarget, dmask)`` on ``self.device`` from a batch:
        N2V pixel manipulation, then optional additive gaussian noise and
        gaussian blurring of the input."""
        # "d" prefix: tensors on self.device (i.e. probably on GPU)
        dimg = batch['inp'].to(self.device, non_blocking=True)
        if self.n2v_ratio > 0:
            dinp, dtarget, dmask = prepare_sample(dimg, ratio=self.n2v_ratio)
        else:
            # No N2V: train input against the (possibly noised) image itself
            dinp, dtarget, dmask = dimg.clone(), dimg, None
        if self.agn_max_std > 0:
            # Draw std uniformly from [0, agn_max_std], apply noise in-place
            noise_std = np.random.rand() * self.agn_max_std
            dinp.add_(torch.randn_like(dinp).mul_(noise_std))
        if self.gblur_sigma > 0:
            # gaussian_filter runs on CPU/NumPy, so round-trip through host
            blurred = dinp.cpu().numpy()
            for n, c in itertools.product(range(blurred.shape[0]), range(blurred.shape[1])):
                blurred[n, c] = gaussian_filter(blurred[n, c], sigma=self.gblur_sigma)
            dinp = torch.as_tensor(blurred).to(self.device).float()
        return dinp, dtarget, dmask

    def _compute_loss(self, dout, dtarget, dmask):
        """Apply the criterion, passing the N2V mask only when present."""
        if dmask is None:
            return self.criterion(dout, dtarget)
        return self.criterion(dout, dtarget, dmask)

    def _train_step(self, batch):
        """One optimization step: degrade batch, forward, loss, backward."""
        dinp, dtarget, dmask = self._degrade_batch(batch)
        # forward pass (mixed precision if enabled)
        with amp.autocast(enabled=self.mixed_precision):
            dout = self.model(dinp)
            dloss = self._compute_loss(dout, dtarget, dmask)
        if torch.isnan(dloss):
            logger.error('NaN loss detected! Aborting training.')
            raise NaNException
        # update step
        self.scaler.scale(dloss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        return dloss, dout

    @torch.no_grad()
    def _validate(self):
        self.model.eval()  # Set dropout and batchnorm to eval mode
        val_loss = []
        outs = []
        targets = []
        stats = {name: [] for name in self.valid_metrics.keys()}
        batch_iter = tqdm(
            enumerate(self.valid_loader),
            'Validating',
            total=len(self.valid_loader),
            dynamic_ncols=True,
            **self.tqdm_kwargs
        )
        for i, batch in batch_iter:
            dinp, dtarget, dmask = self._degrade_batch(batch)
            # forward pass
            with amp.autocast(enabled=self.mixed_precision):
                dout = self.model(dinp)
                dloss = self._compute_loss(dout, dtarget, dmask)
            val_loss.append(dloss.item())
            outs.append(dout.detach().cpu())
            targets.append(dtarget)
            # Only the last batch's images survive the loop; they are
            # returned for preview/visualization purposes.
            images = {
                'inp': dinp.cpu().numpy(),
                'out': dout.cpu().numpy(),
                'target': None if dtarget is None else dtarget.cpu().numpy(),
                'fname': batch.get('fname'),
            }
            self._put_current_attention_maps_into(images)
        stats['val_loss'] = np.mean(val_loss)
        stats['val_loss_std'] = np.std(val_loss)
        for name, evaluator in self.valid_metrics.items():
            mvals = [evaluator(target, out) for target, out in zip(targets, outs)]
            # NaN-aware mean so single failing batches don't poison the metric
            stats[name] = np.nan if np.all(np.isnan(mvals)) else np.nanmean(mvals)
        self.model.train()  # Reset model to training mode
        return stats, images
if __name__ == '__main__':
    # Demo of Noise2Void training sample generation
    import matplotlib.pyplot as plt
    try:
        from scipy.datasets import ascent  # SciPy >= 1.10
    except ImportError:
        from scipy.misc import ascent  # Deprecated; removed in SciPy 1.12

    im = ascent()[::2, ::2]  # Downsample demo image by 2 in each dim
    imt = torch.as_tensor(im)[None, None]  # Add (N, C) dims
    inp, target, mask = prepare_sample(imt, 1e-3)
    fig, axes = plt.subplots(ncols=3, constrained_layout=True, figsize=(20, 12))
    axes[0].imshow(im, cmap='gray')
    axes[0].set_title('Original image')
    axes[1].imshow(mask[0, 0])
    axes[1].set_title('Mask')
    axes[2].imshow(inp[0, 0], cmap='gray')
    axes[2].set_title('Manipulated image for Noise2Void training')
    plt.show()