# Source code for elektronn3.modules.lovasz_losses

# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Maxim Berman, Martin Drawitsch

# This file is mostly copied from https://github.com/bermanmaxim/LovaszSoftmax.
# Modifications for elektronn3:
# - Support 5D images
# - Avoid zero division in low-precision training by adding a small constant eps in divisions

"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)

"""

from __future__ import print_function, division

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse


# Small constant added to the denominators of the IoU/Jaccard ratios computed
# in this module (see lovasz_grad, iou_binary and iou).
eps = 0.0001  # To avoid divisions by zero, esp. in low-precision training


def lovasz_grad(gt_sorted):
    """
    Compute the gradient of the Lovasz extension w.r.t. sorted errors.

    See Alg. 1 in the Lovasz-Softmax paper.

    Args:
        gt_sorted: 1D tensor of binary ground-truth values, ordered the same
            way as the (descending) sorted errors. May be of any dtype that
            can be cast with ``.float()``, including bool.

    Returns:
        1D float tensor with the same length as ``gt_sorted``.
    """
    p = len(gt_sorted)
    # Cast once up front: ``1 - x`` is not defined for bool tensors, and all
    # following arithmetic is carried out in floating point anyway.
    gt = gt_sorted.float()
    gts = gt.sum()
    intersection = gts - gt.cumsum(0)
    union = gts + (1 - gt).cumsum(0)
    # eps keeps the division finite when the union is zero.
    jaccard = 1. - intersection / (union + eps)
    if p > 1:  # cover 1-pixel case
        # Turn the cumulative Jaccard values into per-position differences.
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU of the foreground class, in percent.

    binary: 1 foreground, 0 background

    Args:
        preds: predicted masks (iterated image by image if ``per_image``).
        labels: ground-truth masks, paired with ``preds``.
        EMPTY: value used for an image whose union is empty.
        ignore: label value excluded from the union on the prediction side.
        per_image: if True, compute one IoU per (pred, label) pair and
            average them; otherwise treat the inputs as a single pair.

    Returns:
        100 times the mean IoU.
    """
    if per_image:
        pairs = zip(preds, labels)
    else:
        pairs = [(preds, labels)]

    scores = []
    for pred, label in pairs:
        label_fg = (label == 1)
        pred_fg = (pred == 1)
        intersection = (label_fg & pred_fg).sum()
        union = (label_fg | (pred_fg & (label != ignore))).sum()
        if union:
            scores.append(float(intersection) / (union + eps))
        else:
            scores.append(EMPTY)
    # mean across images if per_image
    return 100 * mean(scores)
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU (in percent) for each (non ignored) class.

    Args:
        preds: predicted label maps (iterated image by image if ``per_image``).
        labels: ground-truth label maps, paired with ``preds``.
        C: number of classes; class ids ``0 .. C - 1`` are evaluated.
        EMPTY: value used for a class whose union is empty.
        ignore: class id that is skipped and excluded from the unions.
        per_image: if True, compute IoUs per (pred, label) pair and average
            them per class; otherwise treat the inputs as a single pair.

    Returns:
        numpy array with one entry per evaluated class.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        class_ious = []
        for i in range(C):
            # The ignored label is sometimes among predicted classes (ENet - CityScapes)
            if i == ignore:
                continue
            intersection = ((label == i) & (pred == i)).sum()
            union = ((label == i) | ((pred == i) & (label != ignore))).sum()
            if not union:
                class_ious.append(EMPTY)
            else:
                class_ious.append(float(intersection) / (union + eps))
        ious.append(class_ious)
    # Mean across images if per_image. Build a real list: on Python 3 ``map``
    # is lazy and np.array() would wrap the map object instead of its values.
    class_means = [mean(per_class) for per_class in zip(*ious)]
    return 100 * np.array(class_means)
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss.

    Args:
        logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
        labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
        per_image: compute the loss per image instead of per batch
        ignore: void class id

    Returns:
        The loss over the whole batch, or the mean of the per-image losses
        if ``per_image`` is set.
    """
    if not per_image:
        flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
        return lovasz_hinge_flat(flat_logits, flat_labels)

    # Each image is given a leading batch axis of size 1 before flattening.
    per_image_losses = (
        lovasz_hinge_flat(
            *flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
        )
        for log, lab in zip(logits, labels)
    )
    return mean(per_image_losses)
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened inputs.

    Args:
        logits: [P] Variable, logits at each prediction (between -inf and +inf)
        labels: [P] Tensor, binary ground truth labels (0 or 1)

    Returns:
        Scalar loss. If there are no labels, a zero that is still derived
        from ``logits`` is returned.
    """
    if not len(labels):
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.

    # Map labels {0, 1} to signs {-1, +1}.
    signs = labels.float() * 2. - 1.
    errors = 1. - logits * Variable(signs)
    errors_sorted, order = errors.sort(dim=0, descending=True)
    # Reorder the ground truth the same way as the sorted errors.
    gt_sorted = labels[order.data]
    weights = lovasz_grad(gt_sorted)
    return torch.dot(F.relu(errors_sorted), Variable(weights))
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flatten predictions in the batch (binary case).

    Entries whose label equals ``ignore`` are removed.

    Args:
        scores: tensor of per-element scores, any shape.
        labels: tensor of labels with the same number of elements as ``scores``.
        ignore: label value to drop, or None to keep everything.

    Returns:
        Tuple ``(scores, labels)`` of 1D tensors.
    """
    # reshape() instead of view(): view(-1) raises on non-contiguous inputs
    # (e.g. a channel slice of a [B, C, H, W] tensor); reshape copies if needed.
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    return scores[valid], labels[valid]
class StableBCELoss(torch.nn.modules.Module):
    """Binary cross entropy on logits, averaged over all elements."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """
        Compute ``mean(max(x, 0) - x * t + log(1 + exp(-|x|)))``.

        Args:
            input: tensor of logits ``x``.
            target: tensor of targets ``t``, broadcastable against ``input``.

        Returns:
            Scalar tensor holding the mean loss.
        """
        # exp() is only ever applied to non-positive values here.
        log_term = torch.log(1 + torch.exp(-torch.abs(input)))
        positive_part = torch.clamp(input, min=0)
        per_element = positive_part - input * target + log_term
        return per_element.mean()
def binary_xloss(logits, labels, ignore=None):
    """
    Binary cross entropy loss.

    Args:
        logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
        labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
        ignore: void class id

    Returns:
        Mean loss over all non-ignored elements, computed by StableBCELoss.
    """
    flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
    criterion = StableBCELoss()
    targets = Variable(flat_labels.float())
    return criterion(flat_logits, targets)
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss.

    Args:
        probas: [B, C, H, W] Variable, class probabilities at each prediction
            (between 0 and 1)
        labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
        only_present: average only on classes present in ground truth
        per_image: compute the loss per image instead of per batch
        ignore: void class labels

    Returns:
        The loss over the whole batch, or the mean of the per-image losses
        if ``per_image`` is set.
    """
    if not per_image:
        flat_probas, flat_labels = flatten_probas(probas, labels, ignore)
        return lovasz_softmax_flat(flat_probas, flat_labels, only_present=only_present)

    # Each image is given a leading batch axis of size 1 before flattening.
    per_image_losses = (
        lovasz_softmax_flat(
            *flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
            only_present=only_present
        )
        for prob, lab in zip(probas, labels)
    )
    return mean(per_image_losses)
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss on flattened inputs.

    Args:
        probas: [P, C] Variable, class probabilities at each prediction
            (between 0 and 1)
        labels: [P] Tensor, ground truth labels (between 0 and C - 1)
        only_present: average only on classes present in ground truth

    Returns:
        Scalar loss tensor: the mean of the per-class losses. If no class
        contributes a loss, a zero that is still derived from ``probas`` is
        returned, so the result is always a tensor.
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float()  # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        # Reorder the ground truth the same way as the sorted errors.
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    if not losses:
        # No class contributed (e.g. only_present=True and none of the classes
        # occurs in labels). Return a tensor zero with zero gradients, like
        # lovasz_hinge_flat does, instead of the plain int that mean([]) gives.
        return probas.sum() * 0.
    return mean(losses)
def flatten_probas(probas, labels, ignore=None):
    """
    Flatten predictions in the batch.

    Args:
        probas: tensor with the class axis at position 1, e.g. [B, C, H, W]
            or [B, C, D, H, W] (3D images). A 2D [P, C] tensor is passed
            through unchanged.
        labels: tensor of labels with one entry per prediction.
        ignore: label value to drop, or None to keep everything.

    Returns:
        Tuple ``(probas, labels)`` with shapes [P, C] and [P].
    """
    C = probas.shape[1]
    ndim = probas.dim()
    if ndim > 2:
        # Move the class axis last, [B, C, *spatial] -> [B, *spatial, C],
        # then merge batch and spatial axes: B * ... , C = P, C
        order = [0] + list(range(2, ndim)) + [1]
        probas = probas.permute(*order).contiguous().view(-1, C)
    # reshape() also accepts non-contiguous label tensors.
    labels = labels.reshape(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # Boolean-mask indexing keeps the [P, C] layout for any number of valid
    # entries; indexing with nonzero().squeeze() dropped the first axis when
    # exactly one entry was valid.
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss.

    Args:
        logits: class scores, passed to ``F.cross_entropy`` as input.
        labels: target class indices.
        ignore: label value excluded from the loss. If None, 255 is used,
            which was the fixed value before this argument was honoured.

    Returns:
        The value returned by ``F.cross_entropy``.
    """
    # Previously the argument was accepted but never used.
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, Variable(labels), ignore_index=ignore_index)
# --------------------------- HELPER FUNCTIONS ---------------------------
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

    Args:
        l: iterable of values supporting ``+`` and division by an int.
        ignore_nan: if True, values for which ``np.isnan`` is true are skipped.
        empty: value returned when there is nothing to average; the string
            ``'raise'`` makes the function raise instead.

    Returns:
        The single element itself if there is exactly one, otherwise the sum
        of the elements divided by their count.

    Raises:
        ValueError: if there is nothing to average and ``empty == 'raise'``.
    """
    values = iter(l)
    if ignore_nan:
        # A generator expression instead of itertools.ifilterfalse, which only
        # exists under that name on Python 2.
        values = (v for v in values if not np.isnan(v))
    try:
        acc = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    n = 1
    for n, v in enumerate(values, 2):
        # Out-of-place add: ``+=`` would modify the caller's first element
        # in place when it is a tensor or an array.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n