elektronn3.training.metrics module

Metrics and tools for evaluating neural network predictions

References:

Note

sklearn.metrics provides many alternative metric implementations that can be compared with the ones here and can serve as inspiration for future work (http://scikit-learn.org/stable/modules/classes.html#classification-metrics).

For example, to get output equivalent to elektronn3.training.metrics.recall(target, pred, num_classes=2, mean=False) / 100 from scikit-learn, you can compute sklearn.metrics.recall_score(target.view(-1).cpu().numpy(), pred.view(-1).cpu().numpy(), average=None).astype(np.float32).
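
A self-contained sketch of this comparison (the tensors are illustrative placeholders):

>>> import numpy as np
>>> import sklearn.metrics
>>> import torch
>>> from elektronn3.training import metrics
>>> target = torch.tensor([0, 0, 1, 1])  # illustrative ground truth
>>> pred = torch.tensor([0, 1, 1, 1])    # illustrative predictions
>>> # Per-class recall from elektronn3, rescaled from % to [0, 1]:
>>> e3_recall = metrics.recall(target, pred, num_classes=2, mean=False) / 100
>>> # Equivalent per-class recall from scikit-learn:
>>> sk_recall = sklearn.metrics.recall_score(
...     target.view(-1).cpu().numpy(),
...     pred.view(-1).cpu().numpy(),
...     average=None
... ).astype(np.float32)
>>> np.allclose(e3_recall.cpu().numpy(), sk_recall)
True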

For most metrics, we don’t use scikit-learn directly in this module for performance reasons:

  • PyTorch allows us to calculate metrics directly on the GPU.

  • We LRU-cache confusion matrices, so multiple metrics can be computed cheaply from the same predictions.

class elektronn3.training.metrics.AMI(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'AMI'
class elektronn3.training.metrics.ARI(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'ARI'
class elektronn3.training.metrics.AUROC(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'AUROC'
class elektronn3.training.metrics.Accuracy(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'accuracy'
class elektronn3.training.metrics.AveragePrecision(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'AP'
class elektronn3.training.metrics.DSC(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'DSC'
class elektronn3.training.metrics.Evaluator(metric_fn, index=None, ignore=None, self_supervised=False)[source]

Bases: object

name: str = 'generic'
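
The Evaluator subclasses in this module (Accuracy, DSC, IoU, ...) bind metric_fn to the corresponding metric function, so they can typically be instantiated without arguments. A minimal sketch of building a valid_metrics dict for elektronn3.training.Trainer from them (the dict keys are arbitrary log names chosen here for illustration):

>>> from elektronn3.training import metrics
>>> valid_metrics = {
...     'val_accuracy': metrics.Accuracy(),
...     'val_DSC': metrics.DSC(),
...     'val_IoU': metrics.IoU(),
... }
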
class elektronn3.training.metrics.IoU(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'IoU'
class elektronn3.training.metrics.NMI(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'NMI'
class elektronn3.training.metrics.Precision(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'precision'
class elektronn3.training.metrics.Recall(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'recall'
class elektronn3.training.metrics.SilhouetteScore(*args, **kwargs)[source]

Bases: Evaluator

name: str = 'silhouette_score'
elektronn3.training.metrics.accuracy(target, pred, num_classes=2, mean=True, ignore=None)[source]

Accuracy metric (in %)
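
A short self-contained sketch (the tensors are illustrative; the same call pattern applies to dice_coefficient, iou, precision and recall below):

>>> import torch
>>> from elektronn3.training import metrics
>>> target = torch.tensor([0, 1, 1, 0])
>>> pred = torch.tensor([0, 1, 0, 0])
>>> metrics.accuracy(target, pred, num_classes=2)  # mean over classes, in %
>>> metrics.accuracy(target, pred, num_classes=2, mean=False)  # per-class values, in %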

elektronn3.training.metrics.auroc(target, probs, mean=True)[source]

Area under the curve (AUC) of the receiver operating characteristic (ROC) curve (in %).

Note

This implementation uses scikit-learn on the CPU to do the heavy lifting, so it’s relatively slow (one call can take about 1 second for typical inputs).

elektronn3.training.metrics.average_precision(target, probs, mean=True)[source]

Average precision (AP) metric based on the precision-recall (PR) curve (in %).

Note

This implementation uses scikit-learn on the CPU to do the heavy lifting, so it’s relatively slow (one call can take about 1 second for typical inputs).

elektronn3.training.metrics.bin_accuracy(target, out)[source]
elektronn3.training.metrics.bin_auroc(target, out)[source]
elektronn3.training.metrics.bin_average_precision(target, out)[source]
elektronn3.training.metrics.bin_dice_coefficient(target, out)[source]
elektronn3.training.metrics.bin_iou(target, out)[source]
elektronn3.training.metrics.bin_precision(target, out)[source]
elektronn3.training.metrics.bin_recall(target, out)[source]
elektronn3.training.metrics.channel_metric(metric, c, num_classes, argmax=True)[source]

Returns an evaluator that calculates the metric and selects its value for channel c.

Example

>>> from elektronn3.training import metrics
>>> num_classes = 5  # Example. Depends on model and data set
>>> # Metric evaluator dict that registers DSCs of all output channels.
>>> # You can pass it to elektronn3.training.Trainer as the ``valid_metrics``
>>> #  argument to make it log these values.
>>> dsc_evaluators = {
...    f'val_DSC_c{c}': metrics.channel_metric(
...        metrics.dice_coefficient,
...        c=c, num_classes=num_classes
...    )
...    for c in range(num_classes)
... }
elektronn3.training.metrics.confusion_matrix(target, pred, num_classes=2, dtype=torch.float32, device=torch.device('cpu'), nan_when_empty=True, ignore=None)[source]

Calculate per-class confusion matrix.

Uses an LRU cache, so subsequent calls with the same arguments are very cheap.

Parameters:
  • pred (LongTensor) – Tensor with predicted class values

  • target (LongTensor) – Ground truth tensor with true class values

  • num_classes (int) – Number of classes that the target can assume. E.g. for binary classification, this is 2. Classes are expected to start at 0.

  • dtype (dtype) – torch.dtype to be used for calculation and output. torch.float32 is used as default because it is robust against overflows and can be used directly in true divisions without re-casting.

  • device (device) – PyTorch device on which to store the confusion matrix

  • nan_when_empty (bool) – If True (default), the confusion matrix row of each class that has no positive entries in the target tensor will be filled with NaN values.

  • ignore (Optional[int]) – Index to be ignored for cm calculation

Return type:

Tensor

Returns:

Confusion matrix cm, with shape (num_classes, 4), where each row cm[c] contains (in this order) the counts of true positives, true negatives, false positives and false negatives of pred w.r.t. target and class c.

E.g. cm[1][2] contains the number of false positive predictions of class 1. If nan_when_empty is enabled and there are no positive elements of class 1 in target, cm[1] will instead be filled with NaN values.
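
A minimal usage sketch; the expected values follow from the (TP, TN, FP, FN) row layout described above:

>>> import torch
>>> from elektronn3.training import metrics
>>> target = torch.tensor([0, 0, 1, 1])
>>> pred = torch.tensor([0, 1, 1, 1])
>>> cm = metrics.confusion_matrix(target, pred, num_classes=2)
>>> cm[1]  # class 1: 2 TP, 1 TN, 1 FP, 0 FN
tensor([2., 1., 1., 0.])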

elektronn3.training.metrics.dice_coefficient(target, pred, num_classes=2, mean=True, ignore=None)[source]

Sørensen–Dice coefficient a.k.a. DSC a.k.a. F1 score (in %)

elektronn3.training.metrics.iou(target, pred, num_classes=2, mean=True, ignore=None)[source]

IoU (Intersection over Union) a.k.a. IU a.k.a. Jaccard index (in %)

elektronn3.training.metrics.precision(target, pred, num_classes=2, mean=True, ignore=None)[source]

Precision metric (in %)

elektronn3.training.metrics.recall(target, pred, num_classes=2, mean=True, ignore=None)[source]

Recall metric a.k.a. sensitivity a.k.a. hit rate (in %)
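
Like the other per-class metrics, recall supports per-class output via mean=False and exclusion of a class index via ignore (cf. the ignore parameter of confusion_matrix above). A short illustrative sketch:

>>> import torch
>>> from elektronn3.training import metrics
>>> target = torch.tensor([0, 0, 1, 1])
>>> pred = torch.tensor([0, 1, 1, 1])
>>> metrics.recall(target, pred, num_classes=2, mean=False)  # per-class recall, in %
>>> metrics.recall(target, pred, num_classes=2, ignore=0)    # class 0 is excluded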