# Source code for elektronn3.training.metrics

# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Martin Drawitsch

# TODO: Update docs to show Evaluator
"""Metrics and tools for evaluating neural network predictions

References:

- https://en.wikipedia.org/wiki/Confusion_matrix
- https://stats.stackexchange.com/questions/273537/f1-dice-score-vs-iou
- http://scikit-learn.org/stable/modules/model_evaluation.html

.. note::

    ``sklearn.metrics`` has a lot of alternative implementations that can be
    compared with these here and could be used as inspiration for future work
    (http://scikit-learn.org/stable/modules/classes.html#classification-metrics).

    For example, to get the equivalent output to
    ``elektronn3.training.metrics.recall(target, pred, num_classes=2, mean=False) / 100``,
    from scikit-learn, you can compute
    ``sklearn.metrics.recall_score(target.view(-1).cpu().numpy(), pred.view(-1).cpu().numpy(), average=None).astype(np.float32)``.


    For most metrics, we don't use scikit-learn directly in this module for
    performance reasons:

    - PyTorch allows us to calculate metrics directly on GPU
    - We LRU-cache confusion matrices for cheap calculation of multiple metrics
"""

from functools import lru_cache
from typing import Callable, Optional
from unicodedata import name

import sklearn.metrics
import torch
import numpy as np


eps = 1e-4  # Small constant added to metric denominators so they are never zero

# TODO: Tests would make a lot of sense here.


@lru_cache(maxsize=128)
def confusion_matrix(
        target: torch.LongTensor,
        pred: torch.LongTensor,
        num_classes: int = 2,
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device('cpu'),
        nan_when_empty: bool = True,
        ignore: Optional[int] = None,
) -> torch.Tensor:
    """ Calculate per-class confusion matrix.

    Uses an LRU cache, so subsequent calls with the same arguments are very
    cheap. ``torch.Tensor`` objects hash by identity, so the cache is keyed on
    the tensor *objects*. Modifying ``target`` or ``pred`` in place after a
    call therefore returns a stale cached result. The cache also keeps up to
    128 argument tensors alive.

    Args:
        pred: Tensor with predicted class values
        target: Ground truth tensor with true class values
        num_classes: Number of classes that the target can assume.
            E.g. for binary classification, this is 2.
            Classes are expected to start at 0
        dtype: ``torch.dtype`` to be used for calculation and output.
            ``torch.float32`` is used as default because it is robust
            against overflows and can be used directly in true divisions
            without re-casting.
        device: PyTorch device on which to store the confusion matrix.
            The counting itself runs on the device of ``target``/``pred``.
        nan_when_empty: If ``True`` (default), the confusion matrix will
            be filled with NaN values for each channel of which there are
            no positive entries in the ``target`` tensor.
            This check is done before ``ignore`` is applied.
        ignore: Index to be ignored for cm calculation. Elements where
            ``target == ignore`` are excluded from all four counts of
            every class.

    Returns:
        Confusion matrix ``cm``, with shape ``(num_classes, 4)``, where
        each row ``cm[c]`` contains (in this order) the count of

        - true positives
        - true negatives
        - false positives
        - false negatives

        of ``pred`` w.r.t. ``target`` and class ``c``.

        E.g. ``cm[1][2]`` contains the number of false positive predictions
        of class ``1``.
        If ``nan_when_empty`` is enabled and there are no positive elements
        of class ``1`` in ``target``, ``cm[1]`` will instead be filled with
        NaN values.
    """
    cm = torch.empty(num_classes, 4, dtype=dtype, device=device)
    # Build the "valid element" mask once, on the same device as the inputs.
    # The previous version re-wrapped ``target == ignore`` with ``torch.tensor``,
    # which warned and moved the mask to ``device``, breaking GPU inputs.
    if ignore is not None:
        valid = target != ignore
    else:
        # 0-dim ``True`` broadcasts, making ``& valid`` a no-op
        valid = torch.ones((), dtype=torch.bool, device=target.device)
    for c in range(num_classes):
        pos_pred = pred == c
        neg_pred = ~pos_pred
        pos_target = target == c
        neg_target = ~pos_target
        true_pos = (pos_pred & pos_target & valid).sum(dtype=dtype)
        true_neg = (neg_pred & neg_target & valid).sum(dtype=dtype)
        false_pos = (pos_pred & neg_target & valid).sum(dtype=dtype)
        false_neg = (neg_pred & pos_target & valid).sum(dtype=dtype)
        cm[c] = torch.stack([true_pos, true_neg, false_pos, false_neg]).to(device)
        if nan_when_empty and pos_target.sum(dtype=dtype) == 0:
            cm[c] = torch.full((4,), float('nan'))
    return cm
def precision(target, pred, num_classes=2, mean=True, ignore=None):
    """Precision metric (in %)

    Per class ``c``: ``TP / (TP + FP + eps)``, taken from the cached
    confusion matrix.

    Returns:
        If ``mean`` is ``False``, a per-class tensor of shape
        ``(num_classes,)``. Otherwise the mean over classes as a Python
        float. NaN rows of the confusion matrix propagate into the result.
    """
    counts = confusion_matrix(target, pred, num_classes=num_classes, ignore=ignore)
    # Column 0 holds true positives, column 2 false positives
    true_pos = counts[:, 0]
    false_pos = counts[:, 2]
    per_class = true_pos / (true_pos + false_pos + eps)
    result = per_class.mean().item() if mean else per_class
    return result * 100
def recall(target, pred, num_classes=2, mean=True, ignore=None):
    """Recall metric a.k.a. sensitivity a.k.a. hit rate (in %)

    Per class ``c``: ``TP / (TP + FN + eps)``, taken from the cached
    confusion matrix.

    Returns:
        If ``mean`` is ``False``, a per-class tensor of shape
        ``(num_classes,)``. Otherwise the mean over classes as a Python
        float.
    """
    counts = confusion_matrix(target, pred, num_classes=num_classes, ignore=ignore)
    # Column 0 holds true positives, column 3 false negatives
    hits = counts[:, 0]
    misses = counts[:, 3]
    per_class = hits / (hits + misses + eps)
    if not mean:
        return per_class * 100
    return per_class.mean().item() * 100
def accuracy(target, pred, num_classes=2, mean=True, ignore=None):
    """Accuracy metric (in %)

    Per class ``c``: ``(TP + TN) / (TP + TN + FP + FN + eps)``, taken from
    the cached confusion matrix.

    Returns:
        If ``mean`` is ``False``, a per-class tensor of shape
        ``(num_classes,)``. Otherwise the mean over classes as a Python
        float.
    """
    counts = confusion_matrix(target, pred, num_classes=num_classes, ignore=ignore)
    # unbind(dim=1) yields the four count columns, each of shape (num_classes,)
    tp, tn, fp, fn = counts.unbind(dim=1)
    correct = tp + tn
    per_class = correct / (tp + tn + fp + fn + eps)
    if mean:
        per_class = per_class.mean().item()
    return per_class * 100
def dice_coefficient(target, pred, num_classes=2, mean=True, ignore=None):
    """Sørensen–Dice coefficient a.k.a. DSC a.k.a. F1 score (in %)

    Per class ``c``: ``2 * TP / (2 * TP + FP + FN + eps)``, taken from the
    cached confusion matrix.

    Returns:
        If ``mean`` is ``False``, a per-class tensor of shape
        ``(num_classes,)``. Otherwise the mean over classes as a Python
        float.
    """
    counts = confusion_matrix(target, pred, num_classes=num_classes, ignore=ignore)
    tp, _, fp, fn = counts.t()  # rows of the transpose are the count columns
    per_class = 2 * tp / (2 * tp + fp + fn + eps)
    return (per_class.mean().item() if mean else per_class) * 100
def iou(target, pred, num_classes=2, mean=True, ignore=None):
    """IoU (Intersection over Union) a.k.a. IU a.k.a. Jaccard index (in %)

    Per class ``c``: ``TP / (TP + FP + FN + eps)``, taken from the cached
    confusion matrix.

    Returns:
        If ``mean`` is ``False``, a per-class tensor of shape
        ``(num_classes,)``. Otherwise the mean over classes as a Python
        float.
    """
    counts = confusion_matrix(target, pred, num_classes=num_classes, ignore=ignore)
    intersection = counts[:, 0]
    union = intersection + counts[:, 2] + counts[:, 3] + eps
    per_class = intersection / union
    if mean:
        return per_class.mean().item() * 100
    return per_class * 100
def auroc(target, probs, mean=True):
    """ Area under Curve (AuC) of the ROC curve (in %).

    Computed one-vs-rest for each class ``c``, using
    ``sklearn.metrics.roc_auc_score`` on the binary target
    ``target == c`` and the scores ``probs[:, c]``.

    Args:
        target: Ground truth class indices, shape ``(N, [D,], H, W)``.
        probs: Class probabilities, shape ``(N, C, [D,], H, W)``.
        mean: If ``True``, return the mean over classes as a Python float.
            Otherwise return a per-class tensor of shape ``(C,)``.

    Returns:
        AUROC in %. Classes for which ROC AUC is undefined get NaN. This is
        the case when ``target`` contains only class ``c`` or none of it.

    .. note::
        This implementation uses scikit-learn on the CPU to do the heavy
        lifting, so it's relatively slow (one call can take about 1 second
        for typical inputs).
    """
    assert probs.dim() == target.dim() + 1
    num_classes = probs.shape[1]
    # target: (N, [D,], H, W) -> (N*[D,]*H*W,)
    target_npflat = target.detach().reshape(-1).cpu().numpy()
    # probs: (N, C, [D,], H, W) -> (C, N*[D,]*H*W)
    # reshape() instead of view(): the transposed tensor is not contiguous,
    # so view() raises whenever N > 1.
    probs_npflat = probs.detach().transpose(1, 0).reshape(num_classes, -1).cpu().numpy()
    auc = torch.empty(num_classes)
    # Direct roc_auc_score() computation with multi-class arrays can take
    # hours, so split this into binary calculations manually here by looping
    # through classes:
    for c in range(num_classes):
        t = target_npflat == c  # 1 where target is c, 0 everywhere else
        p = probs_npflat[c]  # probs of class c
        if t.all() or not t.any():
            # Only one label value present: ROC AUC is undefined (sklearn
            # would raise). Use NaN, like confusion_matrix(nan_when_empty=True).
            auc[c] = float('nan')
        else:
            auc[c] = sklearn.metrics.roc_auc_score(t, p)
    if mean:
        auc = auc.mean().item()
    return auc * 100
def average_precision(target, probs, mean=True):
    """Average precision (AP) metric based on PR curves (in %).

    Computed one-vs-rest for each class ``c``, using
    ``sklearn.metrics.average_precision_score`` on the binary target
    ``target == c`` and the scores ``probs[:, c]``.

    Args:
        target: Ground truth class indices, shape ``(N, [D,], H, W)``.
        probs: Class probabilities, shape ``(N, C, [D,], H, W)``.
        mean: If ``True``, return the mean over classes as a Python float.
            Otherwise return a per-class tensor of shape ``(C,)``.

    .. note::
        This implementation uses scikit-learn on the CPU to do the heavy
        lifting, so it's relatively slow (one call can take about 1 second
        for typical inputs).
    """
    assert probs.dim() == target.dim() + 1
    num_classes = probs.shape[1]
    # target: (N, [D,], H, W) -> (N*[D,]*H*W,)
    target_npflat = target.detach().reshape(-1).cpu().numpy()
    # probs: (N, C, [D,], H, W) -> (C, N*[D,]*H*W)
    # reshape() instead of view(): the transposed tensor is not contiguous,
    # so view() raises whenever N > 1.
    probs_npflat = probs.detach().transpose(1, 0).reshape(num_classes, -1).cpu().numpy()
    ap = torch.empty(num_classes)
    # Direct average_precision_score() computation with multi-class arrays can take
    # hours, so split this into binary calculations manually here by looping
    # through classes:
    for c in range(num_classes):
        t = target_npflat == c  # 1 where target is c, 0 everywhere else
        p = probs_npflat[c]  # probs of class c
        ap[c] = sklearn.metrics.average_precision_score(t, p)
    if mean:
        ap = ap.mean().item()
    return ap * 100
@lru_cache(maxsize=128)
def _softmax(x, dim=1):
    """Softmax of ``x`` along ``dim``.

    Memoized. Tensors hash by identity, so repeated calls on the same tensor
    object reuse the cached result.
    """
    probabilities = torch.softmax(x, dim=dim)
    return probabilities


@lru_cache(maxsize=128)
def _argmax(x, dim=1):
    """Index of the maximum of ``x`` along ``dim``.

    Memoized in the same way as ``_softmax``.
    """
    indices = torch.argmax(x, dim=dim)
    return indices


# Helper for multi-class metric construction
def channel_metric(metric, c, num_classes, argmax=True):
    """Returns an evaluator that calculates the ``metric`` and selects its
    value for channel ``c``.

    Args:
        metric: Metric function called as
            ``metric(target, pred, num_classes=..., mean=False)``.
        c: Channel (class) index to pick from the metric's result.
        num_classes: Forwarded to ``metric``.
        argmax: If ``True``, the evaluator argmaxes ``out`` along dim 1
            before passing it to ``metric``. Otherwise ``out`` is passed
            unchanged.

    Returns:
        A function ``evaluator(target, out)`` returning element ``c`` of the
        metric's per-channel result.

    Example:
        >>> from elektronn3.training import metrics
        >>> num_classes = 5  # Example. Depends on model and data set
        >>> # Metric evaluator dict that registers DSCs of all output channels.
        >>> # You can pass it to elektronn3.training.Trainer as the ``valid_metrics``
        >>> # argument to make it log these values.
        >>> dsc_evaluators = {
        ...    f'val_DSC_c{c}': channel_metric(
        ...        metrics.dice_coefficient,
        ...        c=c, num_classes=num_classes
        ...    )
        ...    for c in range(num_classes)
        ... }
    """
    def evaluator(target, out):
        if argmax:
            pred = _argmax(out)
        else:
            pred = out
        per_channel = metric(target, pred, num_classes=num_classes, mean=False)
        return per_channel[c]

    return evaluator
# Metric evaluator shortcuts for raw network outputs in binary classification
# tasks ("bin_*"). "Raw" means the outputs have not been softmaxed or argmaxed.
# These shortcuts are deprecated and will be removed later; use the Evaluator
# classes instead.
def bin_precision(target, out):
    """Precision (in %) of class 1 for raw binary-classification outputs.

    ``out`` is argmaxed along dim 1 before scoring.
    Deprecated: use the Evaluator classes instead.
    """
    per_class = precision(target, _argmax(out), num_classes=2, mean=False)
    return per_class[1]  # Only the score for class 1 is reported
def bin_recall(target, out):
    """Recall (in %) of class 1 for raw binary-classification outputs.

    ``out`` is argmaxed along dim 1 before scoring.
    Deprecated: use the Evaluator classes instead.
    """
    per_class = recall(target, _argmax(out), num_classes=2, mean=False)
    return per_class[1]  # Only the score for class 1 is reported
def bin_accuracy(target, out):
    """Accuracy (in %) of class 1 for raw binary-classification outputs.

    ``out`` is argmaxed along dim 1 before scoring.
    Deprecated: use the Evaluator classes instead.
    """
    per_class = accuracy(target, _argmax(out), num_classes=2, mean=False)
    return per_class[1]  # Only the score for class 1 is reported
def bin_dice_coefficient(target, out):
    """Dice coefficient (in %) of class 1 for raw binary-classification outputs.

    ``out`` is argmaxed along dim 1 before scoring.
    Deprecated: use the Evaluator classes instead.
    """
    per_class = dice_coefficient(target, _argmax(out), num_classes=2, mean=False)
    return per_class[1]  # Only the score for class 1 is reported
def bin_iou(target, out):
    """IoU (in %) of class 1 for raw binary-classification outputs.

    ``out`` is argmaxed along dim 1 before scoring.
    Deprecated: use the Evaluator classes instead.
    """
    per_class = iou(target, _argmax(out), num_classes=2, mean=False)
    return per_class[1]  # Only the score for class 1 is reported
def bin_average_precision(target, out):
    """Average precision (in %) of class 1 for raw binary-classification outputs.

    ``out`` is softmaxed along dim 1 before scoring.
    Deprecated: use the Evaluator classes instead.
    """
    per_class = average_precision(target, _softmax(out), mean=False)
    return per_class[1]  # Only the score for class 1 is reported
def bin_auroc(target, out):
    """AUROC (in %) of class 1 for raw binary-classification outputs.

    ``out`` is softmaxed along dim 1 before scoring.
    Deprecated: use the Evaluator classes instead.
    """
    per_class = auroc(target, _softmax(out), mean=False)
    return per_class[1]  # Only the score for class 1 is reported
class Evaluator:
    """Callable wrapper turning a metric function into a scalar evaluator.

    Supervised mode (default):
        ``out`` is argmaxed along dim 1 and passed as ``pred`` to
        ``metric_fn(target, pred, num_classes, mean=False, ignore=ignore)``.
        The result is indexed with ``index``, or averaged over all entries
        when ``index`` is ``None``, and returned as a Python number.

    Self-supervised mode (``self_supervised=True``):
        ``metric_fn(target, out)`` is called and its result returned
        unchanged. Any aggregation is left to ``metric_fn``.

    Args:
        metric_fn: Metric function to wrap.
        index: Entry of the per-class result to report. ``None`` means the
            mean over all entries.
        ignore: Forwarded as ``ignore=`` to ``metric_fn`` in supervised mode.
        self_supervised: Selects self-supervised mode (see above).
    """
    name: str = 'generic'

    def __init__(
            self,
            metric_fn: Callable,
            index: Optional[int] = None,
            ignore: Optional[int] = None,
            self_supervised: Optional[bool] = False
    ):
        self.metric_fn = metric_fn
        self.index = index
        self.ignore = ignore
        # Inferred lazily from ``out.shape[1]`` on the first supervised call
        self.num_classes = None
        self.self_supervised = self_supervised

    def __call__(self, target: torch.Tensor, out: torch.Tensor) -> float:
        if self.self_supervised:
            # Aggregation is handled by metric_fn itself (e.g. an sklearn metric)
            return self.metric_fn(target, out)

        if self.num_classes is None:
            self.num_classes = out.shape[1]
        per_class = self.metric_fn(
            target, _argmax(out), self.num_classes,
            mean=False, ignore=self.ignore
        )
        if self.index is not None:
            return per_class[self.index].item()
        return per_class.mean().item()
class Accuracy(Evaluator):
    """Evaluator reporting the per-class ``accuracy`` metric (in %)."""
    name = 'accuracy'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        # Remaining arguments (index, ignore, self_supervised) go to Evaluator
        super().__init__(accuracy, *evaluator_args, **evaluator_kwargs)
class Precision(Evaluator):
    """Evaluator reporting the per-class ``precision`` metric (in %)."""
    name = 'precision'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        # Remaining arguments (index, ignore, self_supervised) go to Evaluator
        super().__init__(precision, *evaluator_args, **evaluator_kwargs)
class Recall(Evaluator):
    """Evaluator reporting the per-class ``recall`` metric (in %)."""
    name = 'recall'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        # Remaining arguments (index, ignore, self_supervised) go to Evaluator
        super().__init__(recall, *evaluator_args, **evaluator_kwargs)
class IoU(Evaluator):
    """Evaluator reporting the per-class ``iou`` metric (in %)."""
    name = 'IoU'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        # Remaining arguments (index, ignore, self_supervised) go to Evaluator
        super().__init__(iou, *evaluator_args, **evaluator_kwargs)
class DSC(Evaluator):
    """Evaluator reporting the per-class ``dice_coefficient`` metric (in %)."""
    name = 'DSC'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        # Remaining arguments (index, ignore, self_supervised) go to Evaluator
        super().__init__(dice_coefficient, *evaluator_args, **evaluator_kwargs)
class AveragePrecision(Evaluator):
    """Evaluator wrapping ``sklearn.metrics.average_precision_score``.

    NOTE(review): in supervised mode, Evaluator calls
    ``metric_fn(target, pred, num_classes, mean=False, ignore=...)``.
    sklearn's ``average_precision_score`` does not accept these arguments,
    so this evaluator presumably only works with ``self_supervised=True``.
    Confirm before relying on it.
    """
    name = 'AP'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        super().__init__(
            sklearn.metrics.average_precision_score,
            *evaluator_args, **evaluator_kwargs
        )
class AUROC(Evaluator):
    """Evaluator wrapping ``sklearn.metrics.roc_auc_score``.

    NOTE(review): in supervised mode, Evaluator calls
    ``metric_fn(target, pred, num_classes, mean=False, ignore=...)``.
    sklearn's ``roc_auc_score`` does not accept these arguments, so this
    evaluator presumably only works with ``self_supervised=True``.
    Confirm before relying on it.
    """
    name = 'AUROC'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        super().__init__(
            sklearn.metrics.roc_auc_score,
            *evaluator_args, **evaluator_kwargs
        )
class NMI(Evaluator):
    """Evaluator wrapping ``sklearn.metrics.v_measure_score``.

    scikit-learn documents V-measure as equivalent to normalized mutual
    information with arithmetic-mean normalization.
    NOTE(review): presumably only usable with ``self_supervised=True``,
    since the supervised call passes extra ``num_classes``/``mean``/``ignore``
    arguments. Confirm before relying on it.
    """
    name = 'NMI'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        super().__init__(
            sklearn.metrics.v_measure_score,
            *evaluator_args, **evaluator_kwargs
        )
class AMI(Evaluator):
    """Evaluator wrapping ``sklearn.metrics.adjusted_mutual_info_score``.

    NOTE(review): presumably only usable with ``self_supervised=True``,
    since the supervised call passes extra ``num_classes``/``mean``/``ignore``
    arguments. Confirm before relying on it.
    """
    name = 'AMI'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        super().__init__(
            sklearn.metrics.adjusted_mutual_info_score,
            *evaluator_args, **evaluator_kwargs
        )
class SilhouetteScore(Evaluator):
    """Evaluator wrapping ``sklearn.metrics.silhouette_score``.

    In self-supervised mode, Evaluator calls ``metric_fn(target, out)``.
    ``silhouette_score`` takes ``(X, labels)``, so ``target`` is used as the
    sample features and ``out`` as the cluster labels.
    NOTE(review): presumably only usable with ``self_supervised=True``.
    Confirm before relying on it.
    """
    name = 'silhouette_score'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        super().__init__(
            sklearn.metrics.silhouette_score,
            *evaluator_args, **evaluator_kwargs
        )
class ARI(Evaluator):
    """Evaluator wrapping ``sklearn.metrics.adjusted_rand_score``.

    NOTE(review): presumably only usable with ``self_supervised=True``,
    since the supervised call passes extra ``num_classes``/``mean``/``ignore``
    arguments. Confirm before relying on it.
    """
    name = 'ARI'

    def __init__(self, *evaluator_args, **evaluator_kwargs):
        super().__init__(
            sklearn.metrics.adjusted_rand_score,
            *evaluator_args, **evaluator_kwargs
        )