# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Martin Drawitsch
"""Loss functions"""
from typing import Sequence, Optional, Tuple, Callable, Union
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from elektronn3.modules.lovasz_losses import lovasz_softmax
[docs]
class CombinedLoss(torch.nn.Module):
    """Defines a loss function as a weighted sum of combinable loss criteria.

    Args:
        criteria: List of loss criterion modules that should be combined.
        weight: Weight assigned to the individual loss criteria (in the same
            order as ``criteria``). If ``None``, every criterion is weighted
            by 1.
        device: The device on which the criterion weights are allocated at
            construction time. The weights are stored as a buffer, so they
            follow the module when it is moved with ``.to()`` / ``.cuda()``.
    """

    def __init__(
            self,
            criteria: Sequence[torch.nn.Module],
            weight: Optional[Sequence[float]] = None,
            device: torch.device = None
    ):
        super().__init__()
        self.criteria = torch.nn.ModuleList(criteria)
        self.device = device
        if weight is None:
            weight = torch.ones(len(criteria))
        else:
            weight = torch.as_tensor(weight, dtype=torch.float32)
            assert weight.shape == (len(criteria),)
        self.register_buffer('weight', weight.to(self.device))

    def forward(self, *args):
        """Return ``sum_i weight[i] * criteria[i](*args)``.

        All positional arguments are forwarded unchanged to every criterion.
        """
        # Start the accumulator on the device of the weight buffer instead of
        # the (possibly stale) ``self.device``: the buffer moves with the
        # module, ``self.device`` does not. Accumulating out-of-place avoids
        # writing into a tensor that lives on a different device than the
        # criterion outputs.
        loss = torch.zeros((), device=self.weight.device)
        for crit, weight in zip(self.criteria, self.weight):
            loss = loss + weight * crit(*args)
        return loss
[docs]
class FocalLoss(torch.nn.Module):
    """Focal Loss (https://arxiv.org/abs/1708.02002)

    Expects raw outputs, not softmax probs.

    The log-probabilities obtained by a log-softmax over dim 1 are scaled
    element-wise by ``(1 - prob) ** gamma`` and then passed to
    ``torch.nn.NLLLoss``. ``weight``, ``reduction`` and ``ignore_index`` are
    forwarded to that ``NLLLoss``.
    """

    def __init__(self, weight=None, gamma=2., reduction='mean', ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.nll = torch.nn.NLLLoss(
            weight=weight,
            reduction=reduction,
            ignore_index=ignore_index,
        )
        self.log_softmax = torch.nn.LogSoftmax(1)

    def forward(self, output, target):
        log_p = self.log_softmax(output)
        # Modulating factor that down-weights well-classified elements
        modulator = (1 - log_p.exp()) ** self.gamma
        return self.nll(modulator * log_p, target)
[docs]
class SoftmaxBCELoss(torch.nn.Module):
    """Binary cross entropy computed on softmax probabilities.

    A softmax over dim 1 is applied to ``output`` and the result is compared
    to ``target`` with ``torch.nn.BCELoss``. All constructor arguments are
    forwarded to ``torch.nn.BCELoss``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.bce = torch.nn.BCELoss(*args, **kwargs)

    def forward(self, output, target):
        return self.bce(torch.softmax(output, dim=1), target)
# class MultiLabelCrossEntropy(nn.Module):
# def __init__(self, weight=torch.tensor(1.)):
# self.register_buffer('weight', weight)
# def forward(self, output, target):
# assert output.shape == target.shape
# logprobs = torch.nn.functional.log_softmax(output, dim=1)
# wsum = self.weight[None] * torch.sum(-target * logprobs)
# return torch.mean(wsum, dim=1)
[docs]
def global_average_pooling(inp: torch.Tensor) -> torch.Tensor:
    """Average ``inp`` over all of its spatial dimensions.

    Uses adaptive average pooling with output size 1, so the spatial
    dimensions are kept as singleton dimensions.

    Args:
        inp: Tensor with 3, 4 or 5 dimensions, i.e. a batch dimension, a
            channel dimension and 1, 2 or 3 spatial dimensions.

    Returns:
        Tensor with the same number of dimensions as ``inp`` in which every
        spatial dimension has size 1.

    Raises:
        NotImplementedError: If ``inp`` has an unsupported number of
            dimensions.
    """
    # Dispatch on the tensor rank; each pooling function handles exactly
    # one number of spatial dimensions.
    pool_by_ndim = {
        3: F.adaptive_avg_pool1d,
        4: F.adaptive_avg_pool2d,
        5: F.adaptive_avg_pool3d,
    }
    pool = pool_by_ndim.get(inp.ndim)
    if pool is None:
        raise NotImplementedError(
            f'global_average_pooling supports 3, 4 or 5 dimensional inputs, got ndim={inp.ndim}.'
        )
    return pool(inp, 1)
[docs]
class GAPTripletMarginLoss(nn.TripletMarginLoss):
    """Same as ``torch.nn.TripletMarginLoss``, but applies global average
    pooling to anchor, positive and negative tensors before calculating the
    loss.

    NOTE(review): the pooled tensors are passed on exactly as returned by
    ``global_average_pooling``; confirm that the parent loss measures the
    distance along the intended axis for that shape.
    """

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        pooled_anchor = global_average_pooling(anchor)
        pooled_positive = global_average_pooling(positive)
        pooled_negative = global_average_pooling(negative)
        return super().forward(pooled_anchor, pooled_positive, pooled_negative)
[docs]
class MaskedMSELoss(nn.Module):
    """Masked MSE loss where only pixels that are masked are considered.

    Expects an optional binary mask as the third argument.
    If no mask is supplied (``None``), the loss is equivalent to
    ``torch.nn.MSELoss``.

    The mask may have any shape that can be broadcast to the shape of
    ``out``. If the mask selects no element at all, the loss is 0 instead of
    NaN.
    """

    @staticmethod
    def forward(out, target, mask=None):
        """Return the mean squared error over the elements selected by ``mask``.

        Args:
            out: Prediction tensor.
            target: Target tensor, compared element-wise with ``out``.
            mask: Optional mask tensor, broadcastable to ``out``. Elements
                with mask value 0 do not contribute to the loss.
        """
        if mask is None:
            return F.mse_loss(out, target)
        err = F.mse_loss(out, target, reduction='none')
        # Expand the mask to the full error shape so that the normalisation
        # below counts every element that actually contributes to the sum.
        mask = mask.to(err.dtype).expand_as(err)
        num_masked = mask.sum()
        # Guard against 0 / 0 when the mask is empty: the numerator is 0 in
        # that case, so dividing by 1 yields a loss of exactly 0.
        denominator = torch.where(num_masked > 0, num_masked, torch.ones_like(num_masked))
        # Out-of-place multiplication: do not modify the per-element errors
        return (err * mask).sum() / denominator  # Scale by ratio of masked pixels
[docs]
class DistanceWeightedMSELoss(nn.Module):
    """Weighted MSE loss for signed euclidean distance transform targets.

    By setting ``fg_weight`` to a high value, the errors in foreground
    regions are more strongly penalized.
    If ``fg_weight=1`` and ``mask_borders=None``, this loss is equivalent to
    ``torch.nn.MSELoss``.

    Requires that targets are transformed with
    :py:class:`elektronn3.data.transforms.DistanceTransformTarget`

    Per-pixel weights are assigned on the targets as follows:

    - each foreground pixel is weighted by ``fg_weight``
    - each background pixel is weighted by 1.

    Args:
        fg_weight: Weight of foreground elements (``target <= 0``).
        mask_borders: Width of the border frame (along the last two
            dimensions) whose weights are set to 0. ``None`` disables border
            masking. Masked elements still count in the mean.
    """

    def __init__(self, fg_weight=100., mask_borders=40):
        super().__init__()
        self.fg_weight = fg_weight
        self.mask_borders = mask_borders

    def forward(self, output, target):
        mse = nn.functional.mse_loss(output, target, reduction='none')
        with torch.no_grad():
            # This assumes that the target is in the (-1, 1) value range (tanh)
            # and that regions with value <= 0 are foreground.
            weight = torch.ones_like(target)
            weight[target <= 0] = self.fg_weight
            if self.mask_borders is not None:  # Mask out invalid regions that come from same-padding
                o = self.mask_borders
                height, width = target.shape[-2], target.shape[-1]
                # Zero the complete frame of width ``o`` around the last two
                # dimensions (all four sides), not only two of its corners.
                weight[..., :o, :] = 0.
                weight[..., max(height - o, 0):, :] = 0.
                weight[..., :, :o] = 0.
                weight[..., :, max(width - o, 0):] = 0.
        return torch.mean(weight * mse)
def _channelwise_sum(x: torch.Tensor):
    """Sum-reduce all dimensions of a tensor except dimension 1 (C)"""
    # Batch dimension plus every dimension after the channel dimension,
    # i.e. (0, 2, 3, ...)
    non_channel_dims = (0, *range(2, x.dim()))
    return x.sum(dim=non_channel_dims)
# TODO: Dense weight support
[docs]
def dice_loss(probs, target, weight=1., eps=0.0001, smooth=0.):
    """Channel-wise dice loss, averaged over channels.

    Args:
        probs: Softmax probabilities (not raw network outputs) of shape
            (N, C, ...).
        target: Either a one-hot tensor with the same shape as ``probs`` or
            a tensor of class indices of shape (N, ...), which is converted
            to one-hot internally.
        weight: Factor(s) multiplied with the per-channel loss before
            averaging.
        eps: Small constant added to the denominator.
        smooth: Smoothing term added to both numerator and denominator.

    Raises:
        ValueError: If the shape of ``target`` matches neither layout.
    """
    target_shape, probs_shape = target.shape, probs.shape
    if target_shape == probs_shape:
        # Already one-hot
        onehot = target.to(probs.dtype)
    elif target_shape[0] == probs_shape[0] and target_shape[1:] == probs_shape[2:]:
        # Assume dense target storage, convert to one-hot
        onehot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1)
    else:
        raise ValueError(
            f'Target shape {target.shape} is not compatible with output shape {probs.shape}.'
        )

    overlap = _channelwise_sum(probs * onehot)  # (C,)
    total = _channelwise_sum(probs + onehot)  # (C,)
    dice_per_channel = (2 * overlap + smooth) / (total + smooth + eps)  # (C,)
    return (weight * (1 - dice_per_channel)).mean()  # ()
[docs]
class DiceLoss(torch.nn.Module):
    """Generalized Dice Loss, as described in https://arxiv.org/abs/1707.03237.

    Works for n-dimensional data. Assuming that the ``output`` tensor to be
    compared to the ``target`` has the shape (N, C, D, H, W), the ``target``
    can either have the same shape (N, C, D, H, W) (one-hot encoded) or
    (N, D, H, W) (with dense class indices, as in
    ``torch.nn.CrossEntropyLoss``). If the latter shape is detected, the
    ``target`` is automatically internally converted to a one-hot tensor
    for loss calculation.

    Args:
        apply_softmax: If ``True``, a softmax operation is applied to the
            ``output`` tensor before loss calculation. This is necessary if
            your model does not already apply softmax as the last layer.
            If ``False``, ``output`` is assumed to already contain softmax
            probabilities.
        weight: Weight tensor for class-wise loss rescaling.
            Has to be of shape (C,). If ``None``, classes are weighted equally.
        smooth: Smoothing term that is added to both the numerator and the
            denominator of the dice loss formula.
    """

    def __init__(
            self,
            apply_softmax: bool = True,
            weight: Optional[torch.Tensor] = None,
            smooth: float = 0.
    ):
        super().__init__()
        if apply_softmax:
            self.softmax = torch.nn.Softmax(dim=1)
        else:
            # Identity (no softmax). A module is used instead of a lambda so
            # that the loss object stays picklable.
            self.softmax = torch.nn.Identity()
        self.dice = dice_loss
        if weight is None:
            weight = torch.tensor(1.)
        self.register_buffer('weight', weight)
        self.smooth = smooth

    def forward(self, output, target):
        probs = self.softmax(output)
        return self.dice(probs, target, weight=self.weight, smooth=self.smooth)
# TODO: There is some low-hanging fruit for performance optimization
[docs]
class FixMatchSegLoss(nn.Module):
    """Self-supervised loss for semi-supervised semantic segmentation training,
    very similar to the :math:`l_u` loss proposed in
    FixMatch (https://arxiv.org/abs/2001.07685).

    The main difference to FixMatch is the kind of augmentations that are used
    for consistency regularization. In FixMatch, so-called
    "strong augmentations" are applied to the (already "weakly augmented")
    inputs. Most of these strong augmentations only work for image-level
    classification.
    In ``FixMatchSegLoss``, only simple, easily reversible geometric
    augmentations are used currently
    (random xy(z) flipping and random xy rotation in 90 degree steps).
    TODO: Add more augmentations

    This loss combines two different well-established semi-supervised learning
    techniques:

    - consistency regularization: consistency (equivariance) against random
      flipping and random rotation augmentations is enforced
    - pseudo-label training: model argmax predictions are treated as targets
      for a pseudo-supervised cross-entropy training loss
      This only works for settings where argmax makes sense (not suitable for
      regression) and can be disabled with ``enable_pseudo_label=False``.

    Args:
        model: Neural network model to be trained.
        scale: Scalar factor to be multiplied with the loss to adjust its
            magnitude. (If this loss is combined with a standard supervised
            cross entropy, ``scale`` corresponds to the lambda_u hyperparameter
            in FixMatch.)
        enable_pseudo_label: If ``enable_pseudo_label=True``, the inner loss is
            the cross entropy between the argmax pseudo label tensor computed
            from the weakly augmented input and the softmax model output on the
            strongly augmented input. Since this internally uses
            ``torch.nn.CrossEntropyLoss``, the ``model`` is expected to give
            raw, unsoftmaxed outputs.
            This only works for settings where computing the argmax and softmax
            on the outputs makes sense (so classification, not regression).
            If ``enable_pseudo_label=False``, a mean squared error regression
            loss is computed directly on the difference between the two model
            outputs, without computing or using pseudo-labels.
            In this case, the loss is equivalent to the ``R`` loss proposed in
            "Transformation Consistent Self-ensembling Model for
            Semi-supervised Medical Image Segmentation"
            (https://arxiv.org/abs/1903.00348).
            This non-pseudo-label variant of the loss can also be used for
            pixel-level regression training.
        confidence_thresh: (Only applies if ``enable_pseudo_label=True``.)
            The confidence threshold that determines how
            confident the model has to be in each output element's
            classification for it to contribute to the loss. All output
            elements where none of the softmax class probs exceed this
            threshold are masked out from the loss calculation. If every
            element is masked out, the resulting loss is 0. In the FixMatch
            paper, this hyperparameter is called tau.
        ce_weight: (Only applies if ``enable_pseudo_label=True``.)
            Class weight tensor for the inner cross-entropy loss. Should be
            the same as the weight for the supervised cross-entropy loss.
    """

    def __init__(
            self,
            model: nn.Module,
            scale: float = 1.,
            enable_pseudo_label: bool = True,
            confidence_thresh: float = 0.9,
            ce_weight=None
    ):
        super().__init__()
        self.model = model
        self.scale = scale
        self.enable_pseudo_label = enable_pseudo_label
        self.confidence_thresh = confidence_thresh

        if self.enable_pseudo_label:
            self.criterion = nn.CrossEntropyLoss(weight=ce_weight, ignore_index=-100)
        else:
            self.criterion = nn.MSELoss()

    @staticmethod
    def get_random_augmenters(
            ndim: int
    ) -> Tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[torch.Tensor], torch.Tensor]]:
        """Produce a pair of functions ``augment, reverse_augment``, where
        the ``augment`` function applies a random augmentation to a torch
        tensor and the ``reverse_augment`` function performs the reverse
        augmentations if applicable (i.e. for geometrical transformations)
        so pixel-level loss calculation is still correct.

        ``augment`` runs without gradient tracking (it is meant for model
        inputs). ``reverse_augment`` is differentiable, because it is applied
        to model outputs that the loss has to backpropagate through.

        Note that all augmentations are performed on the compute device that
        holds the input, so generally on the GPU.
        """
        # Random rotation angle (in 90 degree steps)
        k90 = torch.randint(0, 4, ()).item()

        # Get a random selection of spatial dims (ranging from [] to [2, 3, ..., example.ndim - 1]
        flip_dims_binary = torch.randint(0, 2, (ndim - 2,))
        flip_dims = (torch.nonzero(flip_dims_binary, as_tuple=False).squeeze(1) + 2).tolist()

        @torch.no_grad()
        def augment(x: torch.Tensor) -> torch.Tensor:
            x = torch.rot90(x, +k90, (-1, -2))
            if len(flip_dims) > 0:
                x = torch.flip(x, flip_dims)

            # # Uncomment to enable additional random brightness and contrast augmentations
            # contrast_std = 0.1
            # brightness_std = 0.1
            # a = torch.randn(x.shape[:2], device=x.device, dtype=x.dtype) * contrast_std + 1.0
            # b = torch.randn(x.shape[:2], device=x.device, dtype=x.dtype) * brightness_std
            # for n in range(x.shape[0]):
            #     for c in range(x.shape[1]):
            #         # Formula based on tf.image.{adjust_contrast,adjust_brightness}
            #         # See https://www.tensorflow.org/api_docs/python/tf/image
            #         m = torch.mean(x[n, c])
            #         x[n, c] = a[n, c] * (x[n, c] - m) + m + b[n, c]

            # # Uncomment to enable additional additive gaussian noise augmentations
            # agn_std = 0.1
            # x.add_(torch.randn_like(x).mul_(agn_std))

            return x

        # Deliberately NOT wrapped in ``torch.no_grad()``: the result has to
        # stay attached to the autograd graph of the model output.
        def reverse_augment(x: torch.Tensor) -> torch.Tensor:
            if len(flip_dims) > 0:  # Check is necessary only on cuda
                x = torch.flip(x, flip_dims)
            x = torch.rot90(x, -k90, (-1, -2))
            return x

        return augment, reverse_augment

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Compute the scaled consistency loss for the unlabeled batch ``inp``.

        The model is evaluated on ``inp`` and on a randomly augmented copy of
        it. The output on the augmented copy is transformed back and compared
        to the output on ``inp`` (either via confidence-thresholded pseudo
        labels or via a direct MSE, see the class documentation).
        """
        augment, reverse_augment = self.get_random_augmenters(ndim=inp.ndim)
        aug = augment(inp)
        out = self.model(inp)
        aug_out = self.model(aug)
        aug_out_reversed = reverse_augment(aug_out)

        if self.enable_pseudo_label:
            with torch.no_grad():
                # We need softmax outputs for thresholding
                probs = torch.softmax(out, 1)
                omax, pseudo_label = torch.max(probs, dim=1)
                # Ignore loss on outputs that are lower than the confidence threshold
                mask = omax < self.confidence_thresh
                # Assign special ignore value to all masked elements
                pseudo_label[mask] = self.criterion.ignore_index
                all_ignored = bool(mask.all())
            if all_ignored:
                # A mean-reduced cross entropy over zero valid elements would
                # be NaN. Return a zero that is still connected to the graph
                # so that ``backward()`` keeps working.
                loss = aug_out_reversed.sum() * 0.
            else:
                loss = self.criterion(aug_out_reversed, pseudo_label)
        else:
            loss = self.criterion(aug_out_reversed, out)

        scaled_loss = self.scale * loss
        return scaled_loss
# TODO: Rename and clean up
[docs]
def norpf_dice_loss(probs, target, weight=1., class_weight=1.):
    """Dice-based loss that combines a foreground and a background dice term
    per channel and ignores the first and the last channel in the final
    reduction.

    NOTE(review): despite the scalar defaults, the body calls tensor methods
    on ``weight`` and slices ``class_weight``, so both presumably have to be
    tensors of shape (C,) - confirm against the callers. The hard-coded
    5D views/reductions also assume inputs of shape (N, C, D, H, W), and the
    ignore mask is taken from the first sample of the batch only.

    Args:
        probs: Softmax probabilities (not raw network outputs).
        target: Either one-hot (same shape as ``probs``) or dense class
            indices of shape (N, ...), which are converted to one-hot.
        weight: Per-class weights. If they sum to 0, a zero loss that is
            still connected to ``probs`` is returned.
        class_weight: Per-class scaling of the foreground dice term. Classes
            with ``class_weight <= 0`` are disabled.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If the shape of ``target`` matches neither layout.
    """
    import warnings

    # Probs need to be softmax probabilities, not raw network outputs
    tsh, psh = target.shape, probs.shape

    if tsh == psh:  # Already one-hot
        onehot_target = target.to(probs.dtype)
    elif tsh[0] == psh[0] and tsh[1:] == psh[2:]:  # Assume dense target storage, convert to one-hot
        onehot_target = torch.zeros_like(probs)
        onehot_target.scatter_(1, target.unsqueeze(1), 1)
    else:
        raise ValueError(
            f'Target shape {target.shape} is not compatible with output shape {probs.shape}.'
        )

    if weight.sum() == 0:
        return probs.sum() * 0

    # Inverse of the last target channel of the first sample (presumably the
    # "ignore" channel): 1 where elements should be considered, 0 otherwise.
    ignore_mask = (1 - onehot_target[0][-1]).view(1, 1, *probs.shape[2:])  # inverse ignore

    bg_probs = 1 - probs
    bg_target = 1 - onehot_target
    global_weight = (class_weight > 0).type(probs.dtype)
    # Weighted targets of the first sample, summed over all channels except
    # the first and the last one
    positive_target_mask = (weight.view(1, -1, 1, 1, 1) * onehot_target)[0][1:-1].sum(dim=0).view(1, 1, *probs.shape[2:])
    weight = weight * global_weight
    dense_weight = weight.view(1, -1, 1, 1, 1)

    target_empty = ((onehot_target * ignore_mask).sum(dim=(0, 2, 3, 4)) == 0).type(probs.dtype)
    bg_target_empty = ((bg_target * ignore_mask).sum(dim=(0, 2, 3, 4)) == 0).type(probs.dtype)
    # complete background for weighted classes and target of weighted classes as background for unweighted classes
    needs_positive_target_mark = (dense_weight.sum() == 0).type(probs.dtype)
    bg_mask = torch.ones_like(bg_probs) * dense_weight + needs_positive_target_mark * positive_target_mask * global_weight.view(1, -1, 1, 1, 1)

    # make num/denom 1 for unweighted classes and classes with no target
    intersection = probs * onehot_target * ignore_mask * dense_weight  # (N, C, ...)
    intersection2 = bg_probs * bg_target * ignore_mask * bg_mask  # (N, C, ...)
    denominator = (probs + onehot_target) * ignore_mask * dense_weight  # (N, C, ...)
    denominator2 = (bg_probs + bg_target) * ignore_mask * bg_mask  # (N, C, ...)
    numerator = 2 * class_weight * _channelwise_sum(intersection)  # (C,)
    numerator2 = 2 * _channelwise_sum(intersection2)  # (C,)
    denominator = class_weight * _channelwise_sum(denominator)  # (C,)
    denominator2 = _channelwise_sum(denominator2)  # (C,)

    # Channels without any (weighted) true positives; evaluated before the
    # workarounds below modify the numerator.
    no_tp = (numerator == 0).type(probs.dtype)

    # workarounds for divide by zero
    # unweighted classes get DSC=1
    numerator += (1 - weight)
    denominator += (1 - weight)

    bg_mask_empty = ((bg_mask).sum(dim=(0, 2, 3, 4)) == 0).type(probs.dtype)
    numerator2 *= 1 - bg_mask_empty
    numerator2 += bg_mask_empty
    denominator2 *= 1 - bg_mask_empty
    denominator2 += bg_mask_empty

    # when there is no target, DSC has no meaning and is therefore set to 1 as well
    numerator *= 1 - target_empty
    numerator += target_empty
    denominator *= 1 - target_empty
    denominator += target_empty

    numerator2 *= 1 - bg_target_empty
    numerator2 += bg_target_empty
    denominator2 *= 1 - bg_target_empty
    denominator2 += bg_target_empty

    if (denominator == 0).sum() > 0 or (denominator2 == 0).sum() > 0:
        # Report the problem instead of dropping into an interactive shell,
        # which would block non-interactive training runs.
        warnings.warn(
            f'norpf_dice_loss: zero denominator encountered: {denominator}, {denominator2}',
            RuntimeWarning
        )

    # The background term only contributes for channels without true positives
    loss_per_channel = 1 + no_tp - (numerator / denominator + no_tp * numerator2 / denominator2)  # (C,)

    # Average over the enabled classes, skipping the first and the last channel
    weighted_loss = loss_per_channel[1:-1].sum() / (class_weight[1:-1] > 0).sum()

    if torch.isnan(weighted_loss).sum() or (weighted_loss > 1).sum():
        warnings.warn(
            f'norpf_dice_loss: loss is NaN or greater than 1; loss per channel: {loss_per_channel}',
            RuntimeWarning
        )

    return weighted_loss  # ()
# TODO: Rename and clean up
[docs]
class NorpfDiceLoss(torch.nn.Module):
    """Module wrapper around ``norpf_dice_loss``, a variant of the
    Generalized Dice Loss described in https://arxiv.org/abs/1707.03237.

    Assuming that the ``output`` tensor to be
    compared to the ``target`` has the shape (N, C, D, H, W), the ``target``
    can either have the same shape (N, C, D, H, W) (one-hot encoded) or
    (N, D, H, W) (with dense class indices, as in
    ``torch.nn.CrossEntropyLoss``). If the latter shape is detected, the
    ``target`` is automatically internally converted to a one-hot tensor
    for loss calculation.

    Args:
        apply_softmax: If ``True``, a softmax operation is applied to the
            ``output`` tensor before loss calculation. This is necessary if
            your model does not already apply softmax as the last layer.
            If ``False``, ``output`` is assumed to already contain softmax
            probabilities.
        weight: Weight tensor that is registered as a buffer and passed to
            ``norpf_dice_loss`` as ``weight``.
            NOTE(review): presumably has to be of shape (C,) - confirm.
        class_weight: Weight tensor that is registered as a buffer and passed
            to ``norpf_dice_loss`` as ``class_weight``.
            NOTE(review): presumably has to be of shape (C,) - confirm.
    """

    def __init__(self, apply_softmax=True, weight=torch.tensor(1.), class_weight=torch.tensor(1.)):
        super().__init__()
        if apply_softmax:
            self.softmax = torch.nn.Softmax(dim=1)
        else:
            # Identity (no softmax). A module is used instead of a lambda so
            # that the loss object stays picklable.
            self.softmax = torch.nn.Identity()
        self.dice = norpf_dice_loss
        self.register_buffer('weight', weight)
        self.register_buffer('class_weight', class_weight)

    def forward(self, output, target):
        probs = self.softmax(output)
        return self.dice(probs, target, weight=self.weight, class_weight=self.class_weight)
[docs]
class LovaszLoss(torch.nn.Module):
    """Lovasz-Softmax loss, see https://arxiv.org/abs/1705.08790

    Args:
        apply_softmax: If ``True``, a softmax over dim 1 is applied to the
            ``output`` tensor before it is passed to ``lovasz_softmax``.
            If ``False``, ``output`` is passed on unchanged and is assumed to
            already contain softmax probabilities.
    """

    def __init__(self, apply_softmax=True):
        super().__init__()
        if apply_softmax:
            self.softmax = torch.nn.Softmax(dim=1)
        else:
            # Identity (no softmax). A module is used instead of a lambda so
            # that the loss object stays picklable.
            self.softmax = torch.nn.Identity()
        # lovasz_softmax works on softmax probs, so we still have to apply
        # softmax before passing probs to it
        self.lovasz = lovasz_softmax

    def forward(self, output, target):
        probs = self.softmax(output)
        return self.lovasz(probs, target)
[docs]
class ACLoss(torch.nn.Module):
    """Active Contour loss
    http://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_Learning_Active_Contour_Models_for_Medical_Image_Segmentation_CVPR_2019_paper.pdf

    Supports 2D and 3D data, as long as all spatial dimensions have the same
    size and there are only two output channels.

    Modifications:

    - Using mean instead of sum for reductions to avoid size dependency.
    - Instead of the proposed λ loss component weighting (which leads to
      exploding loss magnitudes for high λ values), a relative weight
      ``region_weight`` is used to balance the components:
      ``ACLoss = (1 - region_weight) * length_term + region_weight * region_term``

    Args:
        num_classes: Number of classes used for one-hot encoding dense
            targets.
        region_weight: Relative weight of the region term, between 0 and 1.
    """

    def __init__(self, num_classes: int, region_weight: float = 0.5):
        assert 0. <= region_weight <= 1., 'region_weight must be between 0 and 1'
        super().__init__()
        self.num_classes = num_classes
        self.region_weight = region_weight

    @staticmethod
    def get_length(output):
        """Length term: mean magnitude of the spatial finite differences of
        ``output`` (shape (B, C, H, W) or (B, C, D, H, W))."""
        if output.ndim == 4:
            dy = output[:, :, 1:, :] - output[:, :, :-1, :]  # y gradient (B, C, H-1, W)
            dx = output[:, :, :, 1:] - output[:, :, :, :-1]  # x gradient (B, C, H, W-1)
            # Crop both gradients to a common (H-2, W-2) shape
            dy = dy[:, :, 1:, :-2] ** 2  # (B, C, H-2, W-2)
            dx = dx[:, :, :-2, 1:] ** 2  # (B, C, H-2, W-2)
            delta_pred = torch.abs(dy + dx)
        elif output.ndim == 5:
            assert output.shape[3] == output.shape[4], 'All spatial dims must have the same size'
            dz = output[:, :, 1:, :, :] - output[:, :, :-1, :, :]  # z gradient (B, C, D-1, H, W)
            dy = output[:, :, :, 1:, :] - output[:, :, :, :-1, :]  # h gradient (B, C, D, H-1, W)
            dx = output[:, :, :, :, 1:] - output[:, :, :, :, :-1]  # w gradient (B, C, D, H, W-1)
            # Crop all gradients to a common (D-2, H-2, W-2) shape
            dz = dz[:, :, 1:, :-2, :-2] ** 2  # (B, C, D-2, H-2, W-2)
            dy = dy[:, :, :-2, 1:, :-2] ** 2  # (B, C, D-2, H-2, W-2)
            dx = dx[:, :, :-2, :-2, 1:] ** 2  # (B, C, D-2, H-2, W-2)
            delta_pred = torch.abs(dz + dy + dx)
        else:
            raise ValueError(
                f'ACLoss supports 4D or 5D outputs, got ndim={output.ndim}.'
            )
        # Small constant keeps the square root differentiable at 0
        length = torch.mean(torch.sqrt(delta_pred + 1e-6))
        return length

    @staticmethod
    def get_region(output, target):
        """Region term. ``target`` must have the same layout as ``output``."""
        region_in = torch.mean(output * (target - 1.) ** 2.)
        region_out = torch.mean((1 - output) * target ** 2.)
        return region_in + region_out

    def forward(self, output, target):
        """Compute the loss between ``output`` (B, C, ...) and ``target``.

        ``target`` can either have the same shape as ``output`` or contain
        dense class indices of shape (B, ...), in which case it is one-hot
        encoded with ``num_classes`` classes.
        """
        assert output.shape[2] == output.shape[3], 'All spatial dims must have the same size'
        if target.ndim == output.ndim - 1:
            onehot = torch.nn.functional.one_hot(target, self.num_classes)  # (B, ..., C)
            # Move the class axis to position 1 while keeping the order of
            # the spatial axes. (Swapping axis 1 with the last axis would
            # also transpose the spatial content of the target.)
            target = onehot.permute(0, onehot.ndim - 1, *range(1, onehot.ndim - 1))  # (B, C, ...)

        length_term = self.get_length(output) if self.region_weight < 1. else 0.
        region_term = self.get_region(output, target) if self.region_weight > 0. else 0.
        loss = (1 - self.region_weight) * length_term + self.region_weight * region_term
        return loss
[docs]
class MixedCombinedLoss(torch.nn.Module):
    """
    Defines a loss function as a weighted sum of combinable loss criteria for multi-class classification with only
    single class ground truths.

    For each voxel, we construct a 2 channel output after the softmax:

    channel 0: background (actual background + all but one classes) = (1-channel 1)
    channel 1: foreground (the one class to which the target corresponds)

    Args:
        class_weight: a manual rescaling weight given to each
            class.
        criteria: List of loss criterion modules that should be combined.
            Only ``torch.nn.NLLLoss`` and ``DiceLoss`` instances are handled.
        criteria_weight: Weight assigned to the individual loss criteria (in the same
            order as ``criteria``).
        device: The device on which the loss should be computed. This needs
            to be set to the device that the loss arguments are allocated on.
        eps: Small constant mixed into the softmax probabilities for
            numerical stability.
    """

    def __init__(self, class_weight, criteria, criteria_weight, device, eps=1e-10, **kwargs):
        super().__init__()
        self.softmax = torch.nn.Softmax(dim=1)
        self.class_weight = class_weight
        self.criteria = torch.nn.ModuleList(criteria)
        self.device = device
        self.eps = eps
        if criteria_weight is not None:
            weight = torch.as_tensor(criteria_weight, dtype=torch.float32)
            assert weight.shape == (len(criteria),)
        else:
            weight = torch.ones(len(criteria))
        self.register_buffer('weight', weight.to(self.device))

    def forward(self, output_direct, target, target_class):
        """Compute the combined two-channel loss.

        Args:
            output_direct: Raw model output with the class scores along dim 1.
            target: Dense targets; every non-zero element is treated as
                foreground.
            target_class: For each sample in the batch, the index of the
                class its target corresponds to.
        """
        # background and that class for each sample in the batch
        assert all([len(torch.unique(sample[0])) <= 2 for sample in target])
        binary_target = torch.zeros_like(target)
        binary_target[target != 0] = 1

        # subtract max for numerical stability
        shifted = output_direct - output_direct.max(dim=1)[0].unsqueeze(1)

        # for dice
        probs = self.softmax(shifted)
        probs = (1 - self.eps) * probs + self.eps  # eps for numerical stability
        fg_probs = probs[(range(probs.shape[0]), target_class)].unsqueeze(1)
        softmax_output = torch.cat([1 - fg_probs, fg_probs], dim=1)
        batch_idx = range(softmax_output.shape[0])

        # for crossentropy: compute in log softmax for the two classes for numerical stability
        exp_shifted = shifted.exp()
        log_exp_sum = exp_shifted.sum(dim=1).unsqueeze(1).log()
        # foreground i: log(softmax(x)[i]) = x[i] - log(\sum_j exp(x[j]))
        log_fg_all = shifted - log_exp_sum

        num_classes = output_direct.shape[1]
        # todo: only compute for target class
        others = [np.arange(num_classes) != cls for cls in range(num_classes)]
        exp_sum_others = torch.stack(
            [exp_shifted[:, others[cls]].sum(dim=1) for cls in range(num_classes)], dim=1
        )
        # background -i: log(1-softmax(x)[i]) = log(\sum_j!=i exp(x[j])) - log(\sum_j exp(x[j]))
        log_bg_all = exp_sum_others.log() - log_exp_sum

        log_fg = log_fg_all[batch_idx, target_class].unsqueeze(1)
        log_bg = log_bg_all[batch_idx, target_class].unsqueeze(1)
        log_softmax_output = torch.cat([log_bg, log_fg], dim=1)

        loss = torch.tensor(0., device=softmax_output.device)
        for crit, crit_weight in zip(self.criteria, self.weight):
            for i in batch_idx:  # todo: not element-wise; process full batch
                sample_target = binary_target[i].unsqueeze(0)
                if isinstance(crit, torch.nn.NLLLoss):
                    crit_loss = crit(log_softmax_output[i].unsqueeze(0), sample_target).mean()
                elif isinstance(crit, DiceLoss):
                    crit_loss = crit(softmax_output[i].unsqueeze(0), sample_target)
                else:
                    raise NotImplementedError()
                loss += crit_loss * crit_weight * self.class_weight[target_class[i]]
        return loss
##### ALTERNATIVE VERSIONS OF DICE LOSS #####
# Version with features that are untested and currently not needed
# Based on https://discuss.pytorch.org/t/one-hot-encoding-with-autograd-dice-loss/9781/5
def __dice_loss_with_cool_extra_features(output, target, weights=None, ignore_index=None):
    """Dice loss variant with optional class weights and an ignore index.

    ``target`` holds dense class indices and is one-hot encoded along dim 1.
    Elements whose target equals ``ignore_index`` are removed from both the
    one-hot target and the denominator. The weighted per-channel losses are
    summed and divided by the number of channels.
    """
    eps = 0.0001
    use_ignore = ignore_index is not None

    onehot = torch.zeros_like(output)
    if use_ignore:
        ignored = target == ignore_index
        # Replace ignored labels by a valid index so that scatter_ works,
        # then blank them out again in the one-hot tensor.
        safe_target = target.clone()
        safe_target[ignored] = 0
        onehot.scatter_(1, safe_target.unsqueeze(1), 1)
        ignored = ignored.unsqueeze(1).expand_as(onehot)
        onehot[ignored] = 0
    else:
        onehot.scatter_(1, target.unsqueeze(1), 1)

    class_weights = 1 if weights is None else weights

    overlap = _channelwise_sum(output * onehot)
    total = output + onehot
    if use_ignore:
        total[ignored] = 0
    dice_per_channel = (2 * overlap) / (_channelwise_sum(total) + eps)

    return (class_weights * (1 - dice_per_channel)).sum() / output.shape[1]
# Very simple version. Only for binary classification. Just included for testing.
# Note that the smooth value is set to 0 and eps is introduced instead, to make it comparable.
# Based on https://github.com/pytorch/pytorch/issues/1249#issuecomment-305088398
def __dice_loss_binary(output, target, smooth=0, eps=0.0001):
    """Dice loss computed over all elements at once (no per-channel split).

    ``target`` holds dense class indices and is one-hot encoded along dim 1
    to the shape of ``output``.
    """
    onehot = torch.zeros_like(output).scatter_(1, target.unsqueeze(1), 1)
    flat_out, flat_onehot = output.view(-1), onehot.view(-1)
    overlap = torch.sum(flat_out * flat_onehot)
    dice = (2 * overlap + smooth) / (flat_out.sum() + flat_onehot.sum() + smooth + eps)
    return 1 - dice