elektronn3.inference.inference module

class elektronn3.inference.inference.Predictor(model, state_dict_src=None, device=None, batch_size=None, tile_shape=None, overlap_shape=None, offset=None, out_shape=None, out_dtype=None, float16=False, apply_softmax=True, transform=None, augmentations=None, strict_shapes=False, apply_argmax=False, argmax_with_threshold=None, verbose=False, report_inp_stats=False)

Bases: object
Class to perform inference using a torch.nn.Module object either passed directly or loaded from a file.

If both tile_shape and overlap_shape are None, input tensors are fed directly into the model (best for scalar predictions, medium-sized 2D images or very small 3D images). If you define tile_shape and overlap_shape, these are used to slice a large input into smaller overlapping tiles, perform predictions on these tiles independently and finally put the output tiles back together into one dense tensor without overlap. Use this feature if your model has spatially interpretable (dense) outputs and if passing one input sample to the model would result in an out-of-memory error. For more details on this tiling mode, see elektronn3.inference.inference.tiled_apply().

Parameters
model (Union[Module, str]) – Network model to be used for inference. The model can be passed as a torch.nn.Module, or as a path to either a model file or to an elektronn3 save directory:
  - If model is a torch.nn.Module object, it is used directly.
  - If model is a path (string) to a serialized TorchScript module (.pts), it is loaded from the file and mapped to the specified device.
  - If model is a path (string) to a pickled PyTorch module (.pt) (not a pickled state_dict), it is loaded from the file and mapped to the specified device as well.
state_dict_src (Union[str, dict, None]) – Path to a state_dict file (.pth), an already loaded state_dict, or None. If not None, the state_dict of the model is replaced with it.
device (Union[device, str, None]) – Device to run the inference on. Can be a torch.device or a string like 'cpu', 'cuda:0' etc. If not specified (None), available GPUs are used automatically; the CPU is used as a fallback if no GPUs can be found.
batch_size (Optional[int]) – Maximum batch size with which to perform inference. In general, a higher batch_size gives higher prediction speed, but prediction will consume more GPU memory. Reduce the batch_size if you run out of memory. If this is None (default), the input batch size is used as the prediction batch size.
tile_shape (Optional[Tuple[int, …]]) – Spatial shape of the output tiles to use for inference. The spatial shape of the input tensors has to be divisible by the tile_shape.
overlap_shape (Optional[Tuple[int, …]]) – Spatial shape of the overlap by which input tiles are extended w.r.t. the tile_shape of the resulting output tiles. The overlap_shape should be close to the effective receptive field of the network architecture that’s used for inference. Note that tile_shape + 2 * overlap_shape needs to be a valid input shape for the inference network architecture, so depending on your network architecture (especially pooling layers and strides), you might need to adjust your overlap_shape. If your inference fails due to shape issues, as a rule of thumb, try adjusting your overlap_shape so that tile_shape + 2 * overlap_shape is divisible by 16 or 32. If offset (see below) is not None, overlap_shape can’t be specified; it is configured automatically.
offset (Optional[Tuple[int, …]]) – Shape of the offset by which each output tile is smaller than its input tile on each side. This applies to networks using valid convolutions. If offset is specified, overlap_shape (see above) can’t be specified but is configured automatically.
out_shape (Optional[Tuple[int, …]]) – Expected shape of the output tensor. It doesn’t just refer to spatial shape, but to the actual tensor shape of one sample, including the channel dimension C, but excluding the batch dimension N. Note: model(inp) is never actually executed if tiling is used – out_shape is merely used to preallocate the output tensor so it can be filled later. If you know how many channels your model output has (out_channels) and if your model preserves spatial shape, you can easily calculate out_shape yourself as follows:

    >>> out_channels = 2  # e.g. for binary classification
    >>> out_shape = (out_channels, *inp.shape[2:])
out_dtype (Optional[dtype]) – torch dtype that the output will be cast to.
float16 (bool) – If True, deploy the model in float16 (half) precision.
apply_softmax (bool) – If True (default), a softmax operator is automatically appended to the model, in order to get probability tensors as inference outputs from networks that don’t already apply softmax.
apply_argmax (bool) – If True, the argmax of the model output is computed and returned instead of the class score tensor. This can be used for classification if you are only interested in the final argmax classification. This option can speed up predictions. Note that since argmax is not influenced by softmax, apply_softmax can be safely disabled if apply_argmax is True, even if the model was trained with a softmax loss.
transform (Optional[Callable[[ndarray, Optional[ndarray]], Tuple[ndarray, Optional[ndarray]]]]) – Transformation function to be applied to inputs before performing inference. The primary use of this is for normalization. Make sure to use the same normalization parameters for inference as the ones that were used for training of the model. See elektronn3.data.transforms for some implementations. For pure input normalization you can use this template:

    >>> from elektronn3.data import transforms
    >>> # m, s are mean, std of the inputs the model was trained on
    >>> transform = transforms.Normalize(mean=m, std=s)
augmentations (Union[int, Sequence, None]) – List of test-time augmentations or an integer that specifies the number of different flips to be performed as test-time augmentations.
strict_shapes (bool) – If False (default), force the output shape to be a multiple of the tile_shape by padding the input. This allows for greater flexibility of the tile_shape but potentially wastes more computation (the padded region will be passed into the model but will later be discarded from the output tensor). If True, incompatible shapes will result in an error.
verbose (bool) – If True, report inference speed.
report_inp_stats (bool) –
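The shape rules described above for tile_shape, overlap_shape and strict_shapes come down to simple integer arithmetic. The following is a minimal sketch, not part of elektronn3: the helper names padded_spatial_shape and adjust_overlap_shape are hypothetical, and rounding each spatial dimension up to the next multiple of the tile size is an assumption about the padding that strict_shapes=False performs.

```python
import math
from typing import Sequence, Tuple


def padded_spatial_shape(spatial_shape: Sequence[int],
                         tile_shape: Sequence[int]) -> Tuple[int, ...]:
    """Round each spatial dimension up to the next multiple of its tile size.

    Hypothetical helper approximating the padding that strict_shapes=False
    is described to perform.
    """
    return tuple(math.ceil(s / t) * t for s, t in zip(spatial_shape, tile_shape))


def adjust_overlap_shape(tile_shape: Sequence[int],
                         min_overlap: Sequence[int],
                         multiple: int = 16) -> Tuple[int, ...]:
    """Grow each overlap until tile + 2 * overlap is divisible by `multiple`.

    Hypothetical helper for the rule of thumb given under overlap_shape.
    """
    result = []
    for t, o in zip(tile_shape, min_overlap):
        if multiple % 2 == 0 and t % 2 != 0:
            # t + 2 * o always has the parity of t, so an odd tile size can
            # never produce a multiple of an even number.
            raise ValueError(f"odd tile size {t} can never give a multiple of {multiple}")
        while (t + 2 * o) % multiple != 0:
            o += 1
        result.append(o)
    return tuple(result)


# A 100 x 250 input with 64 x 64 tiles is padded to 128 x 256, and a
# minimum overlap of 10 is raised to 16, giving 96 x 96 input tiles.
assert padded_spatial_shape((100, 250), (64, 64)) == (128, 256)
assert adjust_overlap_shape((64, 64), (10, 10)) == (16, 16)
```

Growing the overlap (rather than shrinking it) keeps at least the requested context around every output tile, at the cost of slightly larger input tiles.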
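The transform parameter is typed as a callable that maps an input array and an optional target array to an (input, target) tuple. If you need a normalization that elektronn3.data.transforms does not provide, a plain class with that call signature could look like the following minimal sketch. ChannelNormalize is a hypothetical name, and whether Predictor passes a single sample (C, ...) or a whole batch (N, C, ...) to the transform is an assumption here, which is why the channel axis is configurable.

```python
from typing import Optional, Sequence, Tuple

import numpy as np


class ChannelNormalize:
    """Hypothetical transform: per-channel (x - mean) / std normalization.

    Matches the documented transform signature (inp, target) -> (inp, target);
    the target is passed through unchanged.
    """

    def __init__(self, mean: Sequence[float], std: Sequence[float],
                 channel_axis: int = 0) -> None:
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.asarray(std, dtype=np.float32)
        self.channel_axis = channel_axis

    def __call__(self, inp: np.ndarray, target: Optional[np.ndarray] = None
                 ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        # Reshape the statistics so they broadcast along the channel axis only.
        shape = [1] * inp.ndim
        shape[self.channel_axis] = -1
        mean = self.mean.reshape(shape)
        std = self.std.reshape(shape)
        return (inp - mean) / std, target


# Two channels with known statistics, sample layout (C, H, W).
norm = ChannelNormalize(mean=[1.0, 2.0], std=[2.0, 4.0])
sample = np.stack([np.full((3, 3), 5.0), np.full((3, 3), 10.0)])
normalized, _ = norm(sample)
assert np.allclose(normalized, 2.0)  # (5 - 1) / 2 and (10 - 2) / 4
```

As with transforms.Normalize, the mean and std values must be the ones computed on the training data.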
Examples

    >>> import numpy as np
    >>> from torch import nn
    >>> from elektronn3.inference.inference import Predictor
    >>> model = nn.Sequential(
    ...     nn.Conv2d(5, 32, 3, padding=1), nn.ReLU(),
    ...     nn.Conv2d(32, 2, 1))
    >>> inp = np.random.randn(2, 5, 10, 10)
    >>> predictor = Predictor(model)
    >>> out = predictor.predict(inp)
    >>> assert np.all(np.array(out.shape) == np.array([2, 2, 10, 10]))
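The note under apply_argmax (argmax is not influenced by softmax) holds because, at each position, every class score is exponentiated and divided by the same positive sum, and exp is strictly increasing, so the ranking of the classes is preserved. The following NumPy check illustrates this; the softmax function is written out here for illustration and is not taken from elektronn3.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = 1) -> np.ndarray:
    """Numerically stable softmax over the channel axis (axis 1 in N, C, ... layout)."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 3, 4, 4))  # (N, C, H, W) class scores
probs = softmax(scores, axis=1)

# Class probabilities sum to 1 at every spatial position ...
assert np.allclose(probs.sum(axis=1), 1.0)
# ... and the argmax over channels is the same with and without softmax.
assert np.array_equal(probs.argmax(axis=1), scores.argmax(axis=1))
```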
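When augmentations is an integer, it is described as the number of different flips used as test-time augmentations. The general technique, which may differ in detail from elektronn3's own implementation, is to run the model on flipped copies of the input, undo each flip on the corresponding output and average the results. A NumPy sketch under these assumptions; flip_tta and the order in which flip variants are chosen are hypothetical:

```python
from itertools import combinations
from typing import Callable, List, Tuple

import numpy as np


def flip_tta(func: Callable[[np.ndarray], np.ndarray],
             inp: np.ndarray,
             num_flips: int) -> np.ndarray:
    """Average func's outputs over `num_flips` flip variants of `inp`.

    Hypothetical sketch: inp has layout (N, C, *spatial); each variant flips
    a subset of the spatial axes, starting with the identity (no flip).
    func must preserve the spatial shape.
    """
    spatial_axes = list(range(2, inp.ndim))
    variants: List[Tuple[int, ...]] = []
    for k in range(len(spatial_axes) + 1):
        variants.extend(combinations(spatial_axes, k))
    if not 1 <= num_flips <= len(variants):
        raise ValueError(f"num_flips must be in [1, {len(variants)}]")

    total = None
    for axes in variants[:num_flips]:
        flipped = np.flip(inp, axis=axes) if axes else inp
        out = func(flipped)
        # Undo the flip so every output is aligned with the original input.
        out = np.flip(out, axis=axes) if axes else out
        total = out if total is None else total + out
    return total / num_flips


# A flip-equivariant func (simple scaling) is unchanged by averaging
# over all four 2D flip variants.
x = np.random.default_rng(0).normal(size=(1, 2, 5, 6))
assert np.allclose(flip_tta(lambda a: 2 * a, x, num_flips=4), 2 * x)
```

For a real CNN, which is usually not exactly flip-equivariant, the averaged output tends to be smoother than a single prediction, at the cost of num_flips forward passes.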

predict(inp)

Perform prediction on inp and return the prediction.

Parameters
inp (Union[ndarray, Tensor]) – Input data, e.g. of shape (N, C, H, W). Can be an np.ndarray or a torch.Tensor. Note that inp is automatically converted to the specified dtype (default: torch.float32) before inference.
Return type
Tensor
Returns
Model output
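predict is documented to take input such as (N, C, H, W), so a single 2D or 3D image without batch and channel axes has to be expanded first. A minimal sketch; as_batch is a hypothetical helper, and the float32 cast only mirrors the documented default conversion that predict would also perform itself:

```python
import numpy as np


def as_batch(image: np.ndarray) -> np.ndarray:
    """Expand a single-channel (H, W) or (D, H, W) image to (1, 1, ...).

    Hypothetical helper: prepends the batch (N) and channel (C) axes and
    casts to float32, the documented default input dtype.
    """
    return image[np.newaxis, np.newaxis].astype(np.float32)


img = np.zeros((64, 64), dtype=np.uint8)
batch = as_batch(img)
assert batch.shape == (1, 1, 64, 64) and batch.dtype == np.float32
```

The result can then be passed as predictor.predict(batch). For a multi-channel image that already has a C axis, only the batch axis (image[np.newaxis]) would be added.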
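The batch_size parameter of Predictor caps how many samples are passed to the model at once. Splitting a large input batch into chunks and concatenating the outputs along the batch axis can be sketched as follows; predict_in_batches is a hypothetical helper, and it is an assumption that predict chunks its input in this way:

```python
from typing import Callable, Optional

import numpy as np


def predict_in_batches(func: Callable[[np.ndarray], np.ndarray],
                       inp: np.ndarray,
                       batch_size: Optional[int] = None) -> np.ndarray:
    """Run func on inp in chunks of at most batch_size samples.

    Hypothetical sketch: batch_size=None processes the whole input batch
    at once, mirroring the documented default.
    """
    n = inp.shape[0]
    if batch_size is None:
        batch_size = n
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    outputs = [func(inp[i:i + batch_size]) for i in range(0, n, batch_size)]
    return np.concatenate(outputs, axis=0)


seen = []


def double(x: np.ndarray) -> np.ndarray:
    seen.append(x.shape[0])  # record the size of each chunk
    return 2 * x


data = np.arange(30, dtype=float).reshape(10, 3)
result = predict_in_batches(double, data, batch_size=4)
assert seen == [4, 4, 2]
assert np.array_equal(result, 2 * data)
```

Smaller chunks lower peak memory use, while the output is identical to a single large batch for models whose per-sample output does not depend on other samples in the batch.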

elektronn3.inference.inference.set_state_dict(model, state_dict)

Set state dict of a model.
Also works with torch.nn.DataParallel models.
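torch.nn.DataParallel registers the wrapped network as its module attribute, so the keys of its state_dict carry a "module." prefix (for example "module.0.weight" instead of "0.weight"). One common way to make such checkpoints interchangeable is to strip or add that prefix. The sketch below shows only the key handling; it is not necessarily how set_state_dict works internally, and strip_module_prefix is a hypothetical name:

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")

PREFIX = "module."


def strip_module_prefix(state_dict: Mapping[str, V]) -> "OrderedDict[str, V]":
    """Remove the 'module.' prefix that torch.nn.DataParallel adds to keys.

    Keys without the prefix are kept unchanged, and key order is preserved.
    """
    return OrderedDict(
        (key[len(PREFIX):] if key.startswith(PREFIX) else key, value)
        for key, value in state_dict.items()
    )


converted = strip_module_prefix({"module.0.weight": 1, "module.0.bias": 2})
assert list(converted) == ["0.weight", "0.bias"]
```

With real tensors as values, the converted dictionary could then be passed to model.load_state_dict() of the unwrapped model.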

elektronn3.inference.inference.tiled_apply(func, inp, tile_shape, overlap_shape, offset, out_shape, out_dtype=None, argmax_with_threshold=None, verbose=False)

Splits a tensor into overlapping tiles and applies a function on them independently.

Each tile of the output results from applying a callable func on an input tile which is sliced from a region that has the same center but a larger extent (overlapping with other input regions in the vicinity). Input tensors are also padded with zeros at the boundaries according to the overlap_shape to enable consistent tile shapes.

The overlapping behavior prevents imprecisions of CNNs (and image processing algorithms in general) that appear near the boundaries of inner tiles when applying them on a tiled representation of the input.

By default this function assumes that inp.shape[2:] == func(inp).shape[2:], i.e. that the function keeps the spatial shape unchanged. If func reduces the spatial shape (e.g. by performing valid convolutions) and its output is centered w.r.t. the input, you should specify this shape offset in the offset parameter. This is the same offset that elektronn3.data.cnndata.PatchCreator expects.

It can run on GPU or CPU transparently, depending on the device that inp is allocated on.

Although this function is mainly intended for the purpose of neural network inference, func doesn’t have to be a neural network but can be any Callable[[torch.Tensor], torch.Tensor] that operates on n-dimensional image data of shape (N, C, …) and preserves spatial shape or has a constant offset. (“…” is a placeholder for the spatial dimensions, so for example H(eight) and W(idth).)

Parameters
func (Callable[[Tensor], Tensor]) – Function to be applied on input tiles. Usually this is a neural network model.
inp (Tensor) – Input tensor, usually of shape (N, C, [D,], H, W). n-dimensional tensors of shape (N, C, …) are supported.
tile_shape (Sequence[int]) – Spatial shape of the output tiles to use for inference.
overlap_shape (Sequence[int]) – Spatial shape of the overlap by which input tiles are extended w.r.t. the output tile_shape.
offset (Optional[Sequence[int]]) – Determines the offset by which the output contents are shifted w.r.t. the inputs by func. This should generally be set to half the spatial shape difference between inputs and outputs:

    >>> in_sh = np.array(inp.shape[2:])
    >>> out_sh = np.array(func(inp).shape[2:])
    >>> offset = (in_sh - out_sh) // 2

out_shape (Sequence[int]) – Expected shape of the output tensor that would result from applying func to inp (func(inp).shape). It doesn’t just refer to spatial shape, but to the actual tensor shape including N and C dimensions. Note: func(inp) is never actually executed – out_shape is merely used to preallocate the output tensor so it can be filled later.
out_dtype (Optional[dtype]) – torch.dtype that the output will be cast to.
verbose (bool) – If True, a progress bar will be shown while iterating over the tiles.
argmax_with_threshold (Optional[float]) –
Return type
Tensor
Returns
Output tensor, as a torch tensor of shape out_shape.
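The tiling scheme described for tiled_apply can be illustrated in a few lines of NumPy. The following is a simplified sketch, not elektronn3's implementation: it handles only functions that preserve spatial shape (offset of zero), requires the spatial input shape to be divisible by tile_shape, and takes out_channels in place of a full out_shape. The names tiled_apply_sketch and box3 are hypothetical; box3, a 3 x 3 box filter, also shows that func does not have to be a neural network.

```python
from itertools import product
from typing import Callable, Sequence

import numpy as np


def tiled_apply_sketch(func: Callable[[np.ndarray], np.ndarray],
                       inp: np.ndarray,
                       tile_shape: Sequence[int],
                       overlap_shape: Sequence[int],
                       out_channels: int) -> np.ndarray:
    """Apply func tile by tile to inp of shape (N, C, *spatial).

    Simplified sketch: func must preserve the spatial shape, and every
    spatial dimension of inp must be divisible by the matching tile size.
    """
    spatial = inp.shape[2:]
    if any(s % t for s, t in zip(spatial, tile_shape)):
        raise ValueError("spatial shape must be divisible by tile_shape")

    # Zero-pad the spatial axes by overlap_shape on both sides.
    pad = [(0, 0), (0, 0)] + [(o, o) for o in overlap_shape]
    padded = np.pad(inp, pad, mode="constant")
    # Preallocate the output, analogous to the role of out_shape.
    out = np.empty((inp.shape[0], out_channels, *spatial), dtype=np.float64)

    lead = (slice(None), slice(None))  # keep the N and C axes whole
    starts = [range(0, s, t) for s, t in zip(spatial, tile_shape)]
    for corner in product(*starts):
        # Padded index p corresponds to input index p - o, so this slice
        # covers input region [c - o, c + t + o): the tile plus its overlap.
        in_sl = tuple(slice(c, c + t + 2 * o)
                      for c, t, o in zip(corner, tile_shape, overlap_shape))
        tile_out = func(padded[lead + in_sl])
        # Crop the overlap away and keep only the central output tile.
        crop = tuple(slice(o, o + t) for t, o in zip(tile_shape, overlap_shape))
        out_sl = tuple(slice(c, c + t) for c, t in zip(corner, tile_shape))
        out[lead + out_sl] = tile_out[lead + crop]
    return out


def box3(x: np.ndarray) -> np.ndarray:
    """3 x 3 box filter with zero padding on (N, C, H, W) data."""
    h, w = x.shape[2:]
    p = np.pad(x, [(0, 0), (0, 0), (1, 1), (1, 1)], mode="constant")
    return sum(p[:, :, i:i + h, j:j + w] for i in range(3) for j in range(3))


x = np.random.default_rng(0).normal(size=(2, 3, 12, 8))
full = box3(x)
# An overlap of 1 covers the filter's 1-pixel reach: tiled == untiled.
assert np.allclose(tiled_apply_sketch(box3, x, (4, 4), (1, 1), 3), full)
# Without overlap, values along interior tile borders come out wrong.
assert not np.allclose(tiled_apply_sketch(box3, x, (4, 4), (0, 0), 3), full)
```

The second assertion is the boundary imprecision that the overlapping behavior is meant to prevent; for a CNN, the required overlap grows with its receptive field.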
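For a func built from valid (unpadded) convolutions, each convolution with kernel size k and stride 1 shrinks every spatial dimension by k - 1. The total shrinkage is therefore the sum of (k - 1) over the layers, and for a centered output the offset is half of it, which agrees with offset = (in_sh - out_sh) // 2. A small sketch of this arithmetic, assuming stride-1 convolutions with odd kernel sizes and no pooling; valid_conv_offset is a hypothetical name:

```python
from typing import Sequence, Tuple


def valid_conv_offset(kernel_sizes: Sequence[int],
                      ndim: int = 2) -> Tuple[int, ...]:
    """Offset of a stack of stride-1 valid convolutions (no pooling).

    Each layer with kernel size k shrinks every spatial dimension by k - 1;
    the offset is half of the total shrinkage.
    """
    shrink = sum(k - 1 for k in kernel_sizes)
    if shrink % 2:
        raise ValueError("total shrinkage is odd, so the output is not centered")
    return (shrink // 2,) * ndim


# Three 3 x 3 valid convolutions turn a 64 x 64 input into 58 x 58,
# and (64 - 58) // 2 == 3 on each side.
in_sh, out_sh = (64, 64), (58, 58)
assert valid_conv_offset([3, 3, 3]) == tuple((i - o) // 2 for i, o in zip(in_sh, out_sh))
```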