elektronn3.training.metrics module

Metrics and tools for evaluating neural network predictions.
References:
https://stats.stackexchange.com/questions/273537/f1-dice-score-vs-iou
http://scikit-learn.org/stable/modules/model_evaluation.html
Note

sklearn.metrics offers many alternative implementations that can be compared
with the ones here and could serve as inspiration for future work
(http://scikit-learn.org/stable/modules/classes.html#classification-metrics).

For example, to get the equivalent output to
elektronn3.training.metrics.recall(target, pred, num_classes=2, mean=False) / 100
from scikit-learn, you can compute
sklearn.metrics.recall_score(target.view(-1).cpu().numpy(), pred.view(-1).cpu().numpy(), average=None).astype(np.float32).

For most metrics, we don't use scikit-learn directly in this module, for performance reasons:

- PyTorch allows us to calculate metrics directly on the GPU.
- We LRU-cache confusion matrices for cheap calculation of multiple metrics.
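To make the comparison above concrete, here is a minimal sketch; the random
tensors are toy placeholders (in practice, target and pred come from your
data and model):

>>> import numpy as np
>>> import torch
>>> import sklearn.metrics
>>> from elektronn3.training import metrics
>>> target = torch.randint(0, 2, (1, 8, 8))  # toy ground truth labels
>>> pred = torch.randint(0, 2, (1, 8, 8))    # toy predictions
>>> e3 = metrics.recall(target, pred, num_classes=2, mean=False) / 100
>>> skl = sklearn.metrics.recall_score(
...     target.view(-1).cpu().numpy(),
...     pred.view(-1).cpu().numpy(),
...     average=None
... ).astype(np.float32)
>>> # Both now hold per-class recall as fractions and should agree up to
>>> # floating point precision.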
class elektronn3.training.metrics.Evaluator(metric_fn, index=None, ignore=None)

    Bases: object
elektronn3.training.metrics.accuracy(target, pred, num_classes=2, mean=True, ignore=None)

    Accuracy metric (in %).
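    A minimal usage sketch with toy tensors (the shapes here are just an
    example; target and pred only need to have matching shapes):

    >>> import torch
    >>> from elektronn3.training import metrics
    >>> target = torch.tensor([0, 1, 1, 0])  # ground truth class values
    >>> pred = torch.tensor([0, 1, 0, 0])    # predicted class values
    >>> acc = metrics.accuracy(target, pred, num_classes=2)  # mean accuracy in %
    >>> # With mean=False, you get one accuracy value per class instead.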
elektronn3.training.metrics.auroc(target, probs, mean=True)

    Area under the curve (AuC) of the ROC curve (in %).

    Note: This implementation uses scikit-learn on the CPU to do the heavy
    lifting, so it's relatively slow (one call can take about 1 second for
    typical inputs).
elektronn3.training.metrics.average_precision(target, probs, mean=True)

    Average precision (AP) metric based on PR curves (in %).

    Note: This implementation uses scikit-learn on the CPU to do the heavy
    lifting, so it's relatively slow (one call can take about 1 second for
    typical inputs).
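    Unlike the prediction-based metrics above, auroc and average_precision
    take the model's probability output. A sketch, assuming target is a
    class map of shape (N, ...) and probs is the corresponding softmax
    output of shape (N, C, ...) (these shapes are an assumption, not part
    of the documented signatures):

    >>> import torch
    >>> from elektronn3.training import metrics
    >>> target = torch.randint(0, 2, (1, 16, 16))  # toy class map
    >>> logits = torch.randn(1, 2, 16, 16)         # toy network output
    >>> probs = torch.softmax(logits, dim=1)       # per-class probabilities
    >>> auc = metrics.auroc(target, probs)             # AuC of ROC, in %
    >>> ap = metrics.average_precision(target, probs)  # AP, in %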
elektronn3.training.metrics.channel_metric(metric, c, num_classes, argmax=True)

    Returns an evaluator that calculates the metric and selects its value
    for channel c.

    Example:

    >>> from elektronn3.training import metrics
    >>> num_classes = 5  # Example. Depends on model and data set
    >>> # Metric evaluator dict that registers DSCs of all output channels.
    >>> # You can pass it to elektronn3.training.Trainer as the ``valid_metrics``
    >>> # argument to make it log these values.
    >>> dsc_evaluators = {
    ...    f'val_DSC_c{c}': metrics.channel_metric(
    ...        metrics.dice_coefficient,
    ...        c=c, num_classes=num_classes
    ...    )
    ...    for c in range(num_classes)
    ... }
elektronn3.training.metrics.confusion_matrix(target, pred, num_classes=2, dtype=torch.float32, device=torch.device('cpu'), nan_when_empty=True, ignore=None)

    Calculate the per-class confusion matrix.

    Uses an LRU cache, so subsequent calls with the same arguments are very cheap.
    Parameters:

    - pred (LongTensor) – Tensor with predicted class values
    - target (LongTensor) – Ground truth tensor with true class values
    - num_classes (int) – Number of classes that the target can assume.
      E.g. for binary classification, this is 2. Classes are expected to
      start at 0.
    - dtype (dtype) – torch.dtype to be used for calculation and output.
      torch.float32 is used as default because it is robust against
      overflows and can be used directly in true divisions without
      re-casting.
    - device (device) – PyTorch device on which to store the confusion matrix.
    - nan_when_empty (bool) – If True (default), the confusion matrix will
      be filled with NaN values for each channel of which there are no
      positive entries in the target tensor.
    - ignore (Optional[int]) – Index to be ignored for cm calculation.
    Return type: Tensor

    Returns:

    Confusion matrix cm, with shape (num_classes, 4), where each row cm[c]
    contains (in this order) the counts of

    - true positives
    - true negatives
    - false positives
    - false negatives

    of pred w.r.t. target and class c.

    E.g. cm[1][2] contains the number of false positive predictions of
    class 1. If nan_when_empty is enabled and there are no positive
    elements of class 1 in target, cm[1] will instead be filled with
    NaN values.
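    A small sketch of how to read the cm layout described above, with toy
    tensors:

    >>> import torch
    >>> from elektronn3.training import metrics
    >>> target = torch.tensor([0, 0, 1, 1])
    >>> pred = torch.tensor([0, 1, 1, 1])
    >>> cm = metrics.confusion_matrix(target, pred, num_classes=2)
    >>> tp, tn, fp, fn = cm[1]  # counts for class 1
    >>> # Class 1 here has 2 true positives, 1 true negative, 1 false
    >>> # positive and 0 false negatives, so cm[1] is (2., 1., 1., 0.).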
elektronn3.training.metrics.dice_coefficient(target, pred, num_classes=2, mean=True, ignore=None)

    Sørensen–Dice coefficient, a.k.a. DSC, a.k.a. F1 score (in %).
elektronn3.training.metrics.iou(target, pred, num_classes=2, mean=True, ignore=None)

    IoU (Intersection over Union), a.k.a. IU, a.k.a. Jaccard index (in %).
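    dice_coefficient and iou share the same signature and can be compared
    directly; a toy sketch:

    >>> import torch
    >>> from elektronn3.training import metrics
    >>> target = torch.tensor([0, 0, 1, 1])
    >>> pred = torch.tensor([0, 1, 1, 1])
    >>> dsc = metrics.dice_coefficient(target, pred, num_classes=2)  # in %
    >>> jac = metrics.iou(target, pred, num_classes=2)               # in %
    >>> # For the same prediction, IoU never exceeds DSC, since
    >>> # J = D / (2 - D) (see the stackexchange link at the top).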