Predictor(model, state_dict_src=None, device=None, batch_size=None, tile_shape=None, overlap_shape=None, offset=None, out_shape=None, out_dtype=None, float16=False, apply_softmax=True, transform=None, augmentations=None, strict_shapes=False, apply_argmax=False, argmax_with_threshold=None, verbose=False, report_inp_stats=False)
Class to perform inference using a torch.nn.Module object either passed directly or loaded from a file.
If both tile_shape and overlap_shape are None, input tensors are fed directly into the model (best for scalar predictions, medium-sized 2D images or very small 3D images). If you define tile_shape and overlap_shape, these are used to slice a large input into smaller overlapping tiles, run predictions on these tiles independently, and then assemble the output tiles into one dense tensor without overlap. Use this feature if your model has spatially interpretable (dense) outputs and if passing one input sample to the model would result in an out-of-memory error. For more details on this tiling mode, see tiled_apply().
model – Network model to be used for inference. The model can be passed as a torch.nn.Module, or as a path to either a model file or to an elektronn3 save directory:
- If model is a torch.nn.Module object, it is used directly.
- If model is a path (string) to a serialized TorchScript module (.pts), it is loaded from the file and mapped to the specified device.
- If model is a path (string) to a pickled PyTorch module (.pt) (not a pickled state_dict), it is loaded from the file and mapped to the specified device.
state_dict_src – Path to a state_dict file (.pth), a loaded state_dict, or None. If not None, the state_dict of the model is replaced with it.
device – Device to run the inference on. Can be a torch.device or a string like 'cuda:0' etc. If not specified (None), available GPUs are automatically used; the CPU is used as a fallback if no GPUs can be found.
batch_size – Maximum batch size with which to perform inference. In general, a higher batch_size will give you higher prediction speed, but prediction will consume more GPU memory. Reduce the batch_size if you run out of memory. If this is None (default), the input batch size is used as the prediction batch size.
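As a rough illustration of what a maximum batch size means, the following pure-Python sketch splits N samples into consecutive chunks of at most batch_size samples. The helper name iter_batches is hypothetical, and this is not necessarily how Predictor schedules its work internally.

```python
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(samples: Sequence[T], batch_size: Optional[int] = None) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most ``batch_size`` samples.

    Hypothetical helper: ``batch_size=None`` yields everything in one chunk,
    mirroring the documented default of using the input batch size.
    """
    if batch_size is None:
        batch_size = max(len(samples), 1)
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(samples), batch_size):
        yield list(samples[start:start + batch_size])


# Ten samples with a maximum batch size of 4 give chunks of 4, 4 and 2.
chunks = list(iter_batches(list(range(10)), batch_size=4))
print([len(c) for c in chunks])
```

Smaller chunks lower the peak memory needed per forward pass at the cost of more forward passes.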
tile_shape – Spatial shape of the output tiles to use for inference. The spatial shape of the input tensors has to be divisible by the tile_shape.
overlap_shape – Spatial shape of the overlap by which input tiles are extended w.r.t. the tile_shape of the resulting output tiles. The overlap_shape should be close to the effective receptive field of the network architecture that’s used for inference. Note that tile_shape + 2 * overlap needs to be a valid input shape for the inference network architecture, so depending on your network architecture (especially pooling layers and strides), you might need to adjust your overlap_shape. If your inference fails due to shape issues, as a rule of thumb, try adjusting your overlap_shape so that tile_shape + 2 * overlap is divisible by 16 or 32.
If offset (see below) is not None, overlap_shape can’t be specified; it is configured automatically.
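The divisibility rule of thumb can be checked before running inference. The sketch below is plain shape arithmetic and does not call elektronn3; the helper name adjust_overlap and the example shapes are assumptions made for illustration. It searches for the smallest overlap, not smaller than the requested one, for which tile_shape + 2 * overlap is divisible by a given multiple.

```python
from typing import Sequence, Tuple


def adjust_overlap(
    tile_shape: Sequence[int],
    overlap_shape: Sequence[int],
    multiple: int = 16,
) -> Tuple[int, ...]:
    """Return the smallest overlap >= overlap_shape (per dimension) such that
    tile + 2 * overlap is divisible by ``multiple``.

    Raises ValueError if no such overlap exists for some dimension.
    """
    adjusted = []
    for tile, overlap in zip(tile_shape, overlap_shape):
        # Increasing the overlap by 1 changes the input size by 2, so the
        # remainders repeat after at most ``multiple`` candidates.
        for candidate in range(overlap, overlap + multiple):
            if (tile + 2 * candidate) % multiple == 0:
                adjusted.append(candidate)
                break
        else:
            raise ValueError(
                f"No overlap makes {tile} + 2 * overlap divisible by {multiple}"
            )
    return tuple(adjusted)


tile_shape = (64, 64)
overlap_shape = (10, 12)  # hypothetical starting values
new_overlap = adjust_overlap(tile_shape, overlap_shape, multiple=16)
input_tile_shape = tuple(t + 2 * o for t, o in zip(tile_shape, new_overlap))
print(new_overlap, input_tile_shape)
```

Whether 16, 32 or another multiple is required depends on the number of pooling or strided layers in the network, so treat the multiple as something to verify against your architecture.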
offset – Shape of the offset by which the output tiles are smaller than the input tiles on each side. This applies to networks using valid convolutions. If offset is specified, overlap_shape (see above) can’t be specified; it is configured automatically.
out_shape – Expected shape of the output tensor. It doesn’t just refer to spatial shape, but to the actual tensor shape of one sample, including the channel dimension C, but excluding the batch dimension N. Note: model(inp) is never actually executed if tiling is used; out_shape is merely used to pre-allocate the output tensor so it can be filled later. If you know how many channels your model output has (out_channels) and if your model preserves spatial shape, you can easily calculate out_shape yourself as follows:
>>> out_channels = 2  # Number of model output channels, e.g. 2 for binary classification
>>> out_shape = (out_channels, *inp.shape[2:])
out_dtype – torch dtype that the output will be cast to.
float16 (bool) – If True, deploy the model in float16 (half) precision.
apply_softmax (bool) – If True (default), a softmax operator is automatically appended to the model, in order to get probability tensors as inference outputs from networks that don’t already apply softmax.
apply_argmax (bool) – If True, the argmax of the model output is computed and returned instead of the class score tensor. This can be used for classification if you are only interested in the final argmax classification. This option can speed up predictions. Note that since argmax is not influenced by softmax, apply_softmax can be safely disabled if apply_argmax is True, even if the model was trained with a softmax loss.
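The claim that argmax is not influenced by softmax follows from softmax being strictly increasing in each score. A minimal pure-Python check, using made-up scores and helper functions defined only for this illustration:

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of class scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: List[float]) -> int:
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


logits = [1.5, -0.3, 4.2, 0.0]  # made-up class scores for one pixel
probs = softmax(logits)

# The winning class is the same before and after softmax.
assert argmax(logits) == argmax(probs) == 2
```

Skipping softmax therefore changes the returned values only when class scores are requested, not the class indices returned when argmax is applied.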
transform – Transformation function to be applied to inputs before performing inference. The primary use of this is for normalization. Make sure to use the same normalization parameters for inference as the ones that were used for training of the model. See elektronn3.data.transforms for some implementations. For pure input normalization you can use this template:
>>> from elektronn3.data import transforms
>>> # m, s are mean, std of the inputs the model was trained on
>>> transform = transforms.Normalize(mean=m, std=s)
augmentations – List of test-time augmentations or integer that specifies the number of different flips to be performed as test-time augmentations.
strict_shapes (bool) – If False (default), force the output_shape to be a multiple of the tile_shape by padding the input. This allows for greater flexibility of the tile_shape but potentially wastes more computation (the padded region will be passed into the model but will later be discarded from the output tensor). If True, incompatible shapes will result in an error.
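To estimate how much computation the padding costs when strict_shapes is False, you can compute the padded shape by hand. This is a sketch of the arithmetic only; the helper name pad_to_tile_multiple is hypothetical, and where Predictor actually places the padding is not covered here.

```python
from typing import Sequence, Tuple


def pad_to_tile_multiple(
    spatial_shape: Sequence[int], tile_shape: Sequence[int]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (padded_shape, padding), where padded_shape is the smallest
    shape >= spatial_shape that is a multiple of tile_shape per dimension."""
    # -(-s // t) is ceiling division using integers only.
    padded = tuple(-(-s // t) * t for s, t in zip(spatial_shape, tile_shape))
    padding = tuple(p - s for p, s in zip(padded, spatial_shape))
    return padded, padding


padded_shape, padding = pad_to_tile_multiple((100, 250), (64, 64))
print(padded_shape, padding)
```

A tile_shape that divides the input shape exactly yields zero padding, which is also the case that strict_shapes=True requires.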
verbose (bool) – If True, report inference speed.
>>> import numpy as np
>>> from torch import nn
>>> model = nn.Sequential(
...     nn.Conv2d(5, 32, 3, padding=1), nn.ReLU(),
...     nn.Conv2d(32, 2, 1))
>>> inp = np.random.randn(2, 5, 10, 10)
>>> predictor = Predictor(model)
>>> out = predictor.predict(inp)
>>> assert np.all(np.array(out.shape) == np.array([2, 2, 10, 10]))
Perform prediction on inp and return prediction.
inp – Input data, e.g. of shape (N, C, H, W). Can be an np.ndarray or a torch.Tensor. Note that inp is automatically converted to the specified dtype (torch.float32) before inference.
- Return type
Set state dict of a model.
Also works with
tiled_apply(func, inp, tile_shape, overlap_shape, offset, out_shape, out_dtype=None, argmax_with_threshold=None, verbose=False)
Splits a tensor into overlapping tiles and applies a function on them independently.
Each tile of the output results from applying a callable func on an input tile which is sliced from a region that has the same center but a larger extent (overlapping with other input regions in the vicinity). Input tensors are also padded with zeros at the boundaries according to the overlap_shape to enable consistent tile shapes.
The overlapping behavior prevents imprecisions of CNNs (and image processing algorithms in general) that appear near the boundaries of inner tiles when applying them on a tiled representation of the input.
By default this function assumes that inp.shape[2:] == func(inp).shape[2:], i.e. that the function keeps the spatial shape unchanged. If func reduces the spatial shape (e.g. by performing valid convolutions) and its output is centered w.r.t. the input, you should specify this shape offset in the offset parameter. This is the same offset that
It can run on GPU or CPU transparently, depending on the device that inp is allocated on.
Although this function is mainly intended for the purpose of neural network inference, func doesn’t have to be a neural network but can be any Callable[[torch.Tensor], torch.Tensor] that operates on n-dimensional image data of shape (N, C, …) and preserves spatial shape or has a constant offset. (“…” is a placeholder for the spatial dimensions, so for example H(eight) and W(idth).)
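The tiling idea can be sketched in one dimension without torch. The code below is a simplified stand-in, not the elektronn3 implementation: tiled_apply_1d and moving_average are hypothetical names, the signal is a plain list, and func is assumed to preserve length. It pads the input with zeros, applies func to overlapping input tiles, and keeps only the central part of each output tile.

```python
from typing import Callable, List

Signal = List[float]


def moving_average(x: Signal) -> Signal:
    """3-tap moving average with zero padding; output has the same length."""
    padded = [0.0] + list(x) + [0.0]
    return [(padded[i] + padded[i + 1] + padded[i + 2]) / 3 for i in range(len(x))]


def tiled_apply_1d(func: Callable[[Signal], Signal], inp: Signal, tile: int, overlap: int) -> Signal:
    """Apply ``func`` tile by tile on a 1D signal and stitch the results.

    Assumes that ``func`` preserves length and that ``len(inp)`` is
    divisible by ``tile``.
    """
    if len(inp) % tile != 0:
        raise ValueError("len(inp) must be divisible by tile")
    # Zero-pad the boundaries so that every input tile has the same length.
    padded = [0.0] * overlap + list(inp) + [0.0] * overlap
    out: Signal = []
    for start in range(0, len(inp), tile):
        # Input tile: same center as the output tile, extended by `overlap` per side.
        in_tile = padded[start:start + tile + 2 * overlap]
        out_tile = func(in_tile)
        # Keep only the central part, discarding the overlap margins.
        out.extend(out_tile[overlap:overlap + tile])
    return out


signal = [float(v) for v in range(12)]
dense = moving_average(signal)
tiled = tiled_apply_1d(moving_average, signal, tile=4, overlap=1)
no_overlap = tiled_apply_1d(moving_average, signal, tile=4, overlap=0)
print(tiled == dense, no_overlap == dense)
```

With an overlap that covers the filter radius (1 for a 3-tap filter), the stitched result matches the dense result exactly; with no overlap, values at the inner tile boundaries differ, which is the imprecision the overlap is meant to prevent.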
func – Function to be applied on input tiles. Usually this is a neural network model.
inp (Tensor) – Input tensor, usually of shape (N, C, [D,], H, W). n-dimensional tensors of shape (N, C, …) are supported.
tile_shape – Spatial shape of the output tiles to use for inference.
overlap_shape – Spatial shape of the overlap by which input tiles are extended w.r.t. the output tile_shape.
offset – Determines the offset by which the output contents are shifted w.r.t. the inputs by func. This should generally be set to half the spatial shape difference between inputs and outputs:
>>> import numpy as np
>>> in_sh = np.array(inp.shape[2:])
>>> out_sh = np.array(func(inp).shape[2:])
>>> offset = (in_sh - out_sh) // 2
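If running func on a full input just to measure the offset is too expensive, the offset can often be derived from the architecture instead. The sketch below assumes a plain stack of unpadded, stride-1, undilated convolutions with odd kernel sizes; the helper name valid_conv_offset and the example kernel sizes are hypothetical, and pooling or strided layers are not handled.

```python
from typing import Sequence, Tuple


def valid_conv_offset(kernel_sizes: Sequence[Sequence[int]]) -> Tuple[int, ...]:
    """Per-side offset caused by a stack of unpadded, stride-1 convolutions.

    Each layer with kernel size k shrinks a dimension by k - 1 in total,
    i.e. by (k - 1) // 2 on each side for odd k.
    """
    ndim = len(kernel_sizes[0])
    offset = [0] * ndim
    for kernel in kernel_sizes:
        for dim, k in enumerate(kernel):
            if k % 2 == 0:
                raise ValueError("even kernel sizes give an off-center output")
            offset[dim] += (k - 1) // 2
    return tuple(offset)


# Hypothetical network: two 3x3 convolutions followed by one 5x5 convolution.
offset = valid_conv_offset([(3, 3), (3, 3), (5, 5)])
in_sh = (64, 64)
out_sh = tuple(s - 2 * o for s, o in zip(in_sh, offset))
print(offset, out_sh)
```

The result agrees with the formula above: (64 - 56) // 2 gives 4 per dimension.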
out_shape – Expected shape of the output tensor that would result from applying func to inp (func(inp).shape). It doesn’t just refer to spatial shape, but to the actual tensor shape including N and C dimensions. Note: func(inp) is never actually executed; out_shape is merely used to pre-allocate the output tensor so it can be filled later.
out_dtype – torch.dtype that the output will be cast to.
verbose (bool) – If True, a progress bar will be shown while iterating over the tiles.
- Return type
Output tensor, as a torch tensor of the same shape as the input tensor.