elektronn3.inference.inference module

class elektronn3.inference.inference.Argmax(*args: Any, **kwargs: Any)[source]

Bases: torch.nn.

class elektronn3.inference.inference.FlipAugment(dims)[source]

Bases: object

class elektronn3.inference.inference.Predictor(model, state_dict_src=None, device=None, batch_size=None, tile_shape=None, overlap_shape=None, offset=None, out_shape=None, out_dtype=None, float16=False, apply_softmax=True, transform=None, augmentations=None, strict_shapes=False, apply_argmax=False, argmax_with_threshold=None, verbose=False, report_inp_stats=False)[source]

Bases: object

Class to perform inference using a torch.nn.Module object either passed directly or loaded from a file.

If both tile_shape and overlap_shape are None, input tensors are fed directly into the model (best for scalar predictions, medium-sized 2D images or very small 3D images). If you define tile_shape and overlap_shape, they are used to slice a large input into smaller overlapping tiles. Predictions are performed on these tiles independently, and the output tiles are then assembled into one dense tensor without overlap. Use this feature if your model has spatially interpretable (dense) outputs and if passing one input sample to the model would result in an out-of-memory error. For more details on this tiling mode, see elektronn3.inference.inference.tiled_apply().

  • model (Union[Module, str]) –

    Network model to be used for inference. The model can be passed as a torch.nn.Module, or as a path to either a model file or to an elektronn3 save directory:

    • If model is a torch.nn.Module object, it is used directly.

    • If model is a path (string) to a serialized TorchScript module (.pts), it is loaded from the file and mapped to the specified device.

    • If model is a path (string) to a pickled PyTorch module (.pt) (not a pickled state_dict), it is loaded from the file and mapped to the specified device as well.

  • state_dict_src (Union[str, dict, None]) – Path to state_dict file (.pth) or loaded state_dict or None. If not None, the state_dict of the model is replaced with it.

  • device (Union[device, str, None]) – Device to run the inference on. Can be a torch.device or a string like 'cpu', 'cuda:0' etc. If not specified (None), available GPUs are automatically used; the CPU is used as a fallback if no GPUs can be found.

  • batch_size (Optional[int]) – Maximum batch size with which to perform inference. In general, a higher batch_size will give you higher prediction speed, but prediction will consume more GPU memory. Reduce the batch_size if you run out of memory. If this is None (default), the input batch size is used as the prediction batch size.

  • tile_shape (Optional[Tuple[int, …]]) – Spatial shape of the output tiles to use for inference. The spatial shape of the input tensors has to be divisible by the tile_shape.

  • overlap_shape (Optional[Tuple[int, …]]) –

    Spatial shape of the overlap by which input tiles are extended w.r.t. the tile_shape of the resulting output tiles. The overlap_shape should be close to the effective receptive field of the network architecture that’s used for inference. Note that tile_shape + 2 * overlap needs to be a valid input shape for the inference network architecture, so depending on your network architecture (especially pooling layers and strides), you might need to adjust your overlap_shape. If your inference fails due to shape issues, as a rule of thumb, try adjusting your overlap_shape so that tile_shape + 2 * overlap is divisible by 16 or 32.

    If offset (see below) is not None, overlap_shape can’t be specified but it is configured automatically.

  • offset (Optional[Tuple[int, …]]) – Shape of the offset by which the output tiles are smaller than the input tiles on each side. This applies to networks that use valid convolutions. If offset is specified, overlap_shape (see above) can’t be specified but is configured automatically.

  • out_shape (Optional[Tuple[int, …]]) –

    Expected shape of the output tensor. It doesn’t just refer to spatial shape, but to the actual tensor shape of one sample, including the channel dimension C, but excluding the batch dimension N. Note: model(inp) is never actually executed if tiling is used – out_shape is merely used to pre-allocate the output tensor so it can be filled later. If you know how many channels your model output has (out_channels) and if your model preserves spatial shape, you can easily calculate out_shape yourself as follows:

    >>> out_channels = 2  # E.g. 2 for binary classification; use your model's value
    >>> out_shape = (out_channels, *inp.shape[2:])

  • out_dtype (Optional[dtype]) – torch dtype that the output will be cast to

  • float16 (bool) – If True, deploy the model in float16 (half) precision.

  • apply_softmax (bool) – If True (default), a softmax operator is automatically appended to the model, in order to get probability tensors as inference outputs from networks that don’t already apply softmax.

  • apply_argmax (bool) – If True, the argmax of the model output is computed and returned instead of the class score tensor. This can be used for classification if you are only interested in the final argmax classification. This option can speed up predictions. Note that since argmax is not influenced by softmax, apply_softmax can be safely disabled if apply_argmax is True, even if the model was trained with a softmax loss.

  • transform (Optional[Callable[[ndarray, Optional[ndarray]], Tuple[ndarray, Optional[ndarray]]]]) –

    Transformation function to be applied to inputs before performing inference. The primary use of this is for normalization. Make sure to use the same normalization parameters for inference as the ones that were used for training the model. See elektronn3.data.transforms for some implementations. For pure input normalization you can use this template:

    >>> from elektronn3.data import transforms
    >>> # m, s are mean, std of the inputs the model was trained on
    >>> transform = transforms.Normalize(mean=m, std=s)

  • augmentations (Union[int, Sequence, None]) – List of test-time augmentations or integer that specifies the number of different flips to be performed as test-time augmentations.

  • strict_shapes (bool) – If False (default), force the output_shape to be a multiple of the tile_shape by padding the input. This allows for greater flexibility of the tile_shape but potentially wastes more computation (the padded region will be passed into the model but will later be discarded from the output tensor). If True, incompatible shapes will result in an error.

  • verbose (bool) – If True, report inference speed.

  • report_inp_stats (bool) –
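
The shape constraints described for tile_shape, overlap_shape and strict_shapes come down to simple integer arithmetic. The following is a minimal sketch of that arithmetic only. It does not reproduce elektronn3's internal padding code, and the helper names and example shapes are hypothetical:

```python
def padded_shape(spatial_shape, tile_shape):
    """Round each spatial dimension up to the next multiple of the tile size."""
    return tuple(-(-s // t) * t for s, t in zip(spatial_shape, tile_shape))


def input_tile_shape(tile_shape, overlap_shape):
    """Shape of the tile the network receives: tile_shape + 2 * overlap."""
    return tuple(t + 2 * o for t, o in zip(tile_shape, overlap_shape))


# Hypothetical (D, H, W) example values, chosen for illustration only
spatial_shape = (100, 250, 250)
tile_shape = (32, 64, 64)
overlap_shape = (8, 16, 16)

# With strict_shapes=False, the input is padded up to a multiple of tile_shape
padded = padded_shape(spatial_shape, tile_shape)

# Rule of thumb from the overlap_shape description: divisible by 16 or 32
in_tile = input_tile_shape(tile_shape, overlap_shape)
divisible_by_16 = all(s % 16 == 0 for s in in_tile)

print(padded, in_tile, divisible_by_16)
```

If the divisibility check fails for your architecture, adjusting overlap_shape is usually the least invasive change, because tile_shape also determines how much padding is wasted.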


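>>> import numpy as np
>>> from torch import nn
>>> from elektronn3.inference.inference import Predictor
>>> # Small 2D model that preserves spatial shape: 5 input channels, 2 output channels
>>> model = nn.Sequential(
...     nn.Conv2d(5, 32, 3, padding=1), nn.ReLU(),
...     nn.Conv2d(32, 2, 1))
>>> inp = np.random.randn(2, 5, 10, 10)
>>> predictor = Predictor(model)
>>> out = predictor.predict(inp)
>>> assert np.all(np.array(out.shape) == np.array([2, 2, 10, 10]))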
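
The apply_argmax description states that argmax is not influenced by softmax. This holds because softmax is strictly monotonic: it preserves the ordering of the class scores. The following pure-Python sketch illustrates this for a single score vector. The softmax and argmax helpers are written here for illustration and are not part of elektronn3:

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of class scores."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


logits = [0.3, 2.5, -1.0]   # hypothetical raw class scores for one pixel
probs = softmax(logits)

# The winning class is the same before and after softmax
same_class = argmax(logits) == argmax(probs)
print(argmax(logits), argmax(probs), same_class)
```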

Perform prediction on inp and return prediction.


inp (Union[ndarray, Tensor]) – Input data, e.g. of shape (N, C, H, W). Can be an np.ndarray or a torch.Tensor. Note that inp is automatically converted to the specified dtype (default: torch.float32) before inference.

Return type



Model output

elektronn3.inference.inference.set_state_dict(model, state_dict)[source]

Set state dict of a model.

Also works with torch.nn.DataParallel models.
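
For background: torch.nn.DataParallel stores the wrapped network in its module attribute, so the keys of its state dict carry a "module." prefix that a plain model's keys lack. The sketch below shows that key mismatch using plain dictionaries. It is an illustration only and not necessarily how set_state_dict handles the case; strip_data_parallel_prefix is a hypothetical helper:

```python
def strip_data_parallel_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain lists stand in for tensors to keep the sketch dependency-free
wrapped = {"module.conv.weight": [1.0], "module.conv.bias": [0.0]}
plain = strip_data_parallel_prefix(wrapped)
print(sorted(plain))
```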

elektronn3.inference.inference.tiled_apply(func, inp, tile_shape, overlap_shape, offset, out_shape, verbose=False)[source]

Splits a tensor into overlapping tiles and applies a function on them independently.

Each tile of the output results from applying a callable func on an input tile which is sliced from a region that has the same center but a larger extent (overlapping with other input regions in the vicinity). Input tensors are also padded with zeros at the boundaries according to the overlap_shape to enable consistent tile shapes.

The overlap prevents the inaccuracies that CNNs (and image processing algorithms in general) produce near the boundaries of inner tiles when they are applied to a tiled representation of the input.

By default this function assumes that inp.shape[2:] == func(inp).shape[2:], i.e. that the function keeps the spatial shape unchanged. If func reduces the spatial shape (e.g. by performing valid convolutions) and its output is centered w.r.t. the input, you should specify this shape offset in the offset parameter. This is the same offset that elektronn3.data.cnndata.PatchCreator expects.

It can run on GPU or CPU transparently, depending on the device that inp is allocated on.

Although this function is mainly intended for the purpose of neural network inference, func doesn’t have to be a neural network but can be any Callable[[torch.Tensor], torch.Tensor] that operates on n-dimensional image data of shape (N, C, …) and preserves spatial shape or has a constant offset. (“…” is a placeholder for the spatial dimensions, for example H(eight) and W(idth).)
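
The tiling scheme described above can be illustrated in one dimension with plain Python lists. This is a simplified sketch, not the library's implementation: it ignores the N and C dimensions, assumes the input length is divisible by the tile size, and uses a hypothetical 3-point moving average (smooth) in place of a network:

```python
def smooth(x):
    """3-point moving average with zero padding; keeps the length unchanged."""
    p = [0.0] + list(x) + [0.0]
    return [(p[i] + p[i + 1] + p[i + 2]) / 3 for i in range(len(x))]


def tiled_apply_1d(func, x, tile, overlap):
    """Apply func tile by tile, extending each input tile by overlap per side."""
    n = len(x)
    assert n % tile == 0, "this sketch requires len(x) to be divisible by tile"
    # Zero-pad the boundaries so that every input tile has the same length
    padded = [0.0] * overlap + list(x) + [0.0] * overlap
    out = [0.0] * n
    for start in range(0, n, tile):
        in_tile = padded[start:start + tile + 2 * overlap]
        out_tile = func(in_tile)
        # Keep only the central part, which is unaffected by tile borders
        out[start:start + tile] = out_tile[overlap:overlap + tile]
    return out


x = [float(v * v % 7) for v in range(12)]
full = smooth(x)                                        # untiled reference
tiled = tiled_apply_1d(smooth, x, tile=4, overlap=1)
no_overlap = tiled_apply_1d(smooth, x, tile=4, overlap=0)

print(tiled == full)       # overlap covers the filter's reach
print(no_overlap == full)  # errors appear at inner tile boundaries
```

An overlap of 1 is sufficient here because the filter only looks one element to each side. For a network, the corresponding quantity is its effective receptive field.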

  • func (Callable[[Tensor], Tensor]) – Function to be applied on input tiles. Usually this is a neural network model.

  • inp (Tensor) – Input tensor, usually of shape (N, C, [D,], H, W). n-dimensional tensors of shape (N, C, …) are supported.

  • tile_shape (Sequence[int]) – Spatial shape of the output tiles to use for inference.

  • overlap_shape (Sequence[int]) – Spatial shape of the overlap by which input tiles are extended w.r.t. the output tile_shape.

  • offset (Optional[Sequence[int]]) –

    Determines the offset by which the output contents are shifted w.r.t. the inputs by func. This should generally be set to half the spatial shape difference between inputs and outputs:

    >>> in_sh = np.array(inp.shape[2:])
    >>> out_sh = np.array(func(inp).shape[2:])
    >>> offset = (in_sh - out_sh) // 2

  • out_shape (Sequence[int]) – Expected shape of the output tensor that would result from applying func to inp (func(inp).shape). It doesn’t just refer to spatial shape, but to the actual tensor shape including N and C dimensions. Note: func(inp) is never actually executed – out_shape is merely used to pre-allocate the output tensor so it can be filled later.

  • verbose (bool) – If True, a progress bar will be shown while iterating over the tiles.
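
For a func built from valid convolutions, the offset can also be derived without running func. The sketch below assumes stride-1 convolutions without dilation, where each layer with kernel size k shrinks an axis by k - 1. The helper name, kernel sizes and input shape are hypothetical:

```python
def valid_conv_output_size(in_size, kernel_sizes):
    """Size along one axis after a stack of stride-1 valid convolutions."""
    out = in_size
    for k in kernel_sizes:
        out -= k - 1
    return out


in_shape = (64, 64)         # hypothetical spatial input shape (H, W)
kernel_sizes = [3, 3, 5]    # hypothetical kernel sizes of three valid conv layers

out_shape = tuple(valid_conv_output_size(s, kernel_sizes) for s in in_shape)

# Same formula as in the offset description: half the spatial shape difference
offset = tuple((i - o) // 2 for i, o in zip(in_shape, out_shape))
print(out_shape, offset)
```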

Return type



Output tensor, as a torch tensor of the same shape as the input tensor.