Source code for elektronn3.inference.inference

# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Philipp Schubert, Martin Drawitsch
import copy
import itertools
import logging
import os
import time
import zipfile
from collections import OrderedDict
from typing import Optional, Tuple, Union, Callable, Sequence
from pathlib import Path

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

from elektronn3.data import utils


# TODO: It's confusing that tiled_apply expects out_shape to include the N dim, but
#     Predictor has a parameter with the same name but which doesn't include N.

logger = logging.getLogger('elektronn3log')

# Alias for type hinting
Transform = Callable[
    [np.ndarray, Optional[np.ndarray]],
    Tuple[np.ndarray, Optional[np.ndarray]]
]


def _extend_nc(spatial_slice: Sequence[slice]) -> Tuple[slice, ...]:
    """Extend a spatial slice ([D,] H, W) to also include the non-spatial (N, C) dims."""
    # Slice everything in N and C dims (equivalent to [:, :] in direct notation)
    nonspatial_slice = [slice(None)] * 2
    return tuple(nonspatial_slice + list(spatial_slice))
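
# For example, _extend_nc([slice(0, 4), slice(2, 6)]) evaluates to
# (slice(None), slice(None), slice(0, 4), slice(2, 6)), i.e. the equivalent
# of indexing a (N, C, H, W) tensor with [:, :, 0:4, 2:6].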


# TODO Fix and document out_shape change for argmax outputs

def tiled_apply(
        func: Callable[[torch.Tensor], torch.Tensor],
        inp: torch.Tensor,
        tile_shape: Sequence[int],
        overlap_shape: Sequence[int],
        offset: Optional[Sequence[int]],
        out_shape: Sequence[int],
        verbose: bool = False
) -> torch.Tensor:
    """Splits a tensor into overlapping tiles and applies a function on them independently.

    Each tile of the output results from applying a callable ``func`` on an
    input tile which is sliced from a region that has the same center but a
    larger extent (overlapping with other input regions in the vicinity).
    Input tensors are also padded with zeros at the boundaries according to
    the ``overlap_shape`` to enable consistent tile shapes.

    The overlapping behavior prevents imprecisions of CNNs (and image
    processing algorithms in general) that appear near the boundaries of
    inner tiles when applying them on a tiled representation of the input.

    By default this function assumes that
    ``inp.shape[2:] == func(inp).shape[2:]``, i.e. that the function keeps
    the spatial shape unchanged. If ``func`` reduces the spatial shape
    (e.g. by performing valid convolutions) and its output is centered
    w.r.t. the input, you should specify this shape offset in the ``offset``
    parameter. This is the same offset that
    :py:class:`elektronn3.data.cnndata.PatchCreator` expects.

    It can run on GPU or CPU transparently, depending on the device that
    ``inp`` is allocated on.

    Although this function is mainly intended for the purpose of neural
    network inference, ``func`` doesn't have to be a neural network but can
    be any ``Callable[[torch.Tensor], torch.Tensor]`` that operates on
    n-dimensional image data of shape (N, C, ...) and either preserves the
    spatial shape or has a constant ``offset``.
    ("..." is a placeholder for the spatial dimensions, so for example
    H(eight) and W(idth).)

    Args:
        func: Function to be applied on input tiles. Usually this is a
            neural network model.
        inp: Input tensor, usually of shape (N, C, [D,], H, W).
            n-dimensional tensors of shape (N, C, ...) are supported.
        tile_shape: Spatial shape of the output tiles to use for inference.
        overlap_shape: Spatial shape of the overlap by which input tiles are
            extended w.r.t. the output ``tile_shape``.
        offset: Determines the offset by which the output contents are
            shifted w.r.t. the inputs by ``func``.
            This should generally be set to half the spatial shape difference
            between inputs and outputs:

            >>> in_sh = np.array(inp.shape[2:])
            >>> out_sh = np.array(func(inp).shape[2:])
            >>> offset = (in_sh - out_sh) // 2

        out_shape: Expected shape of the output tensor that would result from
            applying ``func`` to ``inp`` (``func(inp).shape``).
            It doesn't just refer to spatial shape, but to the actual tensor
            shape including N and C dimensions.
            Note: ``func(inp)`` is never actually executed -- ``out_shape``
            is merely used to pre-allocate the output tensor so it can be
            filled later.
        verbose: If ``True``, a progress bar will be shown while iterating
            over the tiles.

    Returns:
        Output tensor of shape ``out_shape`` (allocated on the CPU).
    """
    if not (inp.dim() - 2 == len(tile_shape) == len(overlap_shape)):
        raise ValueError(
            f'ndims of tile shape ({len(tile_shape)}) and overlap shape '
            f'({len(overlap_shape)}) don\'t match input shape ndim - 2 '
            f'({inp.dim() - 2}).'
        )
    if not np.all(np.mod(out_shape[2:], tile_shape) == 0):
        raise ValueError(
            f'spatial out_shape[2:] {tuple(out_shape[2:])} has to be divisible '
            f'by tile_shape {tile_shape}.'
        )
    if offset is not None:
        offset = np.array(offset)
    inp_shape = np.array(inp.shape)
    out = None
    out_shape = np.array(out_shape)
    tile_shape = np.array(tile_shape)
    overlap_shape = np.array(overlap_shape)
    if not np.array_equal(out_shape[2:], inp_shape[2:]):  # input is already padded
        inp_padded = inp
    else:  # Create padded input with overlap
        padded_shape = inp_shape + np.array((0, 0, *overlap_shape * 2))
        logger.info(f'additional input padding to {padded_shape}')
        inp_padded = torch.zeros(tuple(padded_shape), dtype=inp.dtype)
        padslice = _extend_nc([
            slice(l, h) for l, h in zip(overlap_shape, padded_shape[2:] - overlap_shape)
        ])
        inp_padded[padslice] = inp
    crop_low_corner = overlap_shape.copy()
    crop_high_corner = tile_shape + overlap_shape
    # Used to crop the output tile to the relevant, unpadded region
    # that will be written to the final output
    final_crop_slice = _extend_nc([
        slice(l, h) for l, h in zip(crop_low_corner, crop_high_corner)
    ])
    if offset is not None:  # no cropping necessary for valid conv
        final_crop_slice = None
    del inp
    tiles = np.ceil(out_shape[2:] / tile_shape).astype(int)
    num_tiles = np.prod(tiles)
    tile_ranges = [range(t) for t in tiles]
    # TODO: Handle fractional inputshape-to-tile ratio
    pbar = tqdm(
        itertools.product(*tile_ranges), 'Predicting',
        total=num_tiles, disable=not verbose, dynamic_ncols=True
    )
    for tile_pos in pbar:
        tile_pos = np.array(tile_pos)
        # Calculate corner coordinates of the current output tile
        out_low_corner = tile_shape * tile_pos
        out_high_corner = tile_shape * (tile_pos + 1)

        # Note: To understand why the input corners are chosen in this
        #  particular way, it helps to draw this on paper in the 1d case,
        #  representing the input tensor as a line and slicing it into
        #  input and output tiles, where input tiles have a certain overlap
        #  (note that input and output coordinates exist in two different
        #  coordinate systems: input corner coordinates are shifted "right"
        #  by the ``overlap_shape`` w.r.t. the output corner coordinate
        #  system due to the initial padding).
        inp_low_corner = out_low_corner.copy()
        inp_high_corner = out_high_corner.copy() + 2 * overlap_shape

        assert np.all(np.less_equal(inp_high_corner, inp_padded.shape[2:])), inp_high_corner

        # Slice only the current tile region in ([D,] H, W) dims
        # Slice input with overlap
        inp_slice = _extend_nc([
            slice(l, h) for l, h in zip(inp_low_corner, inp_high_corner)
        ])
        # Output slice without overlap (this is the region where the current
        # inference result will be stored)
        out_slice = _extend_nc([
            slice(l, h) for l, h in zip(out_low_corner, out_high_corner)
        ])
        inp_tile = inp_padded[inp_slice].contiguous()
        # ``func`` slices the relevant tile_shape-sized region out of its
        # output (via final_crop_slice) so it can be written to the final
        # output tensor.
        out_tile = func(inp_tile, final_crop_slice)
        # Since out is a CPU tensor, the out[out_slice] assignments below
        # implicitly copy data to the CPU.
        if out is None:
            out = torch.empty(out_shape.tolist(), dtype=out_tile.dtype)
        out[out_slice] = out_tile
    return out
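
# Usage sketch (illustrative shapes only, assuming a spatial-shape-preserving
# model). Note that despite the ``Callable[[torch.Tensor], torch.Tensor]``
# annotation, ``func`` is called with the tile *and* a crop slice, as done by
# ``Predictor._predict`` below, so a standalone ``func`` has to accept both:
# >>> model = nn.Conv3d(1, 2, 3, padding=1).eval()
# >>> func = lambda tile, crop: model(tile) if crop is None else model(tile)[crop]
# >>> inp = torch.randn(1, 1, 64, 64, 64)
# >>> out = tiled_apply(
# ...     func, inp, tile_shape=(32, 32, 32), overlap_shape=(8, 8, 8),
# ...     offset=None, out_shape=(1, 2, 64, 64, 64))
# >>> out.shape
# torch.Size([1, 2, 64, 64, 64])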


class Argmax(nn.Module):
    def __init__(self, dim=1, unsqueeze=True):
        super().__init__()
        self.dim = dim
        self.unsqueeze = unsqueeze

    def forward(self, x):
        argmax = torch.argmax(x, self.dim)
        if self.unsqueeze:
            # Restore C dim as a workaround for the unified slicing pattern in tiled_apply()
            argmax.unsqueeze_(1)
        return argmax
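
# Example (shapes chosen for illustration): applying Argmax to an (N, C, H, W)
# score tensor yields an (N, 1, H, W) class index tensor, because
# unsqueeze=True restores the collapsed channel dim:
# >>> Argmax(dim=1, unsqueeze=True)(torch.randn(1, 3, 8, 8)).shape
# torch.Size([1, 1, 8, 8])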


class FlipAugment:
    def __init__(self, dims):
        self.dims = tuple(np.array(dims) + 2)  # Dim offset to skip (N, C) dims

    def forward(self, inp):
        return torch.flip(inp, dims=self.dims)

    def backward(self, inp):
        return self.forward(inp)
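
# Illustration: FlipAugment((0, 1)) flips spatial dims 0 and 1 (tensor dims
# 2 and 3). Since flipping is its own inverse, forward() followed by
# backward() restores the original tensor:
# >>> aug = FlipAugment((0, 1))
# >>> x = torch.arange(4.).reshape(1, 1, 2, 2)
# >>> torch.equal(aug.backward(aug.forward(x)), x)
# True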


# TODO
# class Rot90Augment:
#     def __init__(self, k, dims):
#         self.k = k
#         self.dims = dims
#
#     def forward(self, inp):
#         return torch.rot90(inp, k=self.k, dims=self.dims)
#
#     def backward(self, inp):
#         return torch.rot90(inp, k=-self.k, dims=self.dims)


DEFAULT_AUGMENTATIONS_3D = [  # Flip every dim
    FlipAugment(dims) for dims in
    [(0,), (1,), (0, 1), (2,), (0, 2), (1, 2), (0, 1, 2)]
]
DEFAULT_AUGMENTATIONS_2D = DEFAULT_AUGMENTATIONS_3D[:3]  # Limit flips to first 2 dims
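
# Passing ``augmentations=<int>`` to ``Predictor`` (see below) selects the
# first <int> entries of DEFAULT_AUGMENTATIONS_3D, e.g. ``augmentations=3``
# uses the first three flip augmentations (the same ones as
# DEFAULT_AUGMENTATIONS_2D).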


class Predictor:
    """Class to perform inference using a ``torch.nn.Module`` object either
    passed directly or loaded from a file.

    If both ``tile_shape`` and ``overlap_shape`` are ``None``, input tensors
    are fed directly into the ``model`` (best for scalar predictions,
    medium-sized 2D images or very small 3D images).
    If you define ``tile_shape`` and ``overlap_shape``, these are used to
    slice a large input into smaller overlapping tiles, perform predictions
    on these tiles independently and later reassemble the output tiles into
    one dense tensor without overlap.
    Use this feature if your model has spatially interpretable (dense)
    outputs and if passing one input sample to the ``model`` would result in
    an out-of-memory error. For more details on this tiling mode, see
    :py:meth:`elektronn3.inference.inference.tiled_apply()`.

    Args:
        model: Network model to be used for inference.
            The model can be passed as a ``torch.nn.Module``, or as a path
            to either a model file or to an elektronn3 save directory:

            - If ``model`` is a ``torch.nn.Module`` object, it is used
              directly.
            - If ``model`` is a path (string) to a serialized TorchScript
              module (.pts), it is loaded from the file and mapped to the
              specified ``device``.
            - If ``model`` is a path (string) to a pickled PyTorch module
              (.pt) (**not** a pickled ``state_dict``), it is loaded from
              the file and mapped to the specified ``device`` as well.
        state_dict_src: Path to a ``state_dict`` file (.pth), a loaded
            ``state_dict`` or ``None``. If not ``None``, the ``state_dict``
            of the ``model`` is replaced with it.
        device: Device to run the inference on. Can be a ``torch.device`` or
            a string like ``'cpu'``, ``'cuda:0'`` etc.
            If not specified (``None``), available GPUs are automatically
            used; the CPU is used as a fallback if no GPUs can be found.
        batch_size: Maximum batch size with which to perform inference.
            In general, a higher ``batch_size`` will give you higher
            prediction speed, but prediction will consume more GPU memory.
            Reduce the ``batch_size`` if you run out of memory.
            If this is ``None`` (default), the input batch size is used as
            the prediction batch size.
        tile_shape: Spatial shape of the output tiles to use for inference.
            The spatial shape of the input tensors has to be divisible by
            the ``tile_shape``.
        overlap_shape: Spatial shape of the overlap by which input tiles are
            extended w.r.t. the ``tile_shape`` of the resulting output tiles.
            The ``overlap_shape`` should be close to the effective receptive
            field of the network architecture that's used for inference.
            Note that ``tile_shape + 2 * overlap`` needs to be a valid input
            shape for the inference network architecture, so depending on
            your network architecture (especially pooling layers and
            strides), you might need to adjust your ``overlap_shape``.
            If your inference fails due to shape issues, as a rule of thumb,
            try adjusting your ``overlap_shape`` so that
            ``tile_shape + 2 * overlap`` is divisible by 16 or 32.
            If ``offset`` (see below) is not ``None``, ``overlap_shape``
            can't be specified but is configured automatically.
        offset: Shape of the offset by which each output tile is smaller than
            its corresponding input tile on each side. This applies for
            networks using valid convolutions. If ``offset`` is specified,
            ``overlap_shape`` (see above) can't be specified but is
            configured automatically.
        out_shape: Expected shape of the output tensor.
            It doesn't just refer to spatial shape, but to the actual tensor
            shape of one sample, including the channel dimension C, but
            **excluding** the batch dimension N.
            Note: ``model(inp)`` is never actually executed if tiling is
            used -- ``out_shape`` is merely used to pre-allocate the output
            tensor so it can be filled later.
            If you know how many channels your model output has
            (``out_channels``) and if your model preserves spatial shape, you
            can easily calculate ``out_shape`` yourself as follows:

            >>> out_channels: int = ?  # E.g. for binary classification it's 2
            >>> out_shape = (out_channels, *inp.shape[2:])
        out_dtype: torch dtype that the output will be cast to.
        float16: If ``True``, deploy the model in float16 (half) precision.
        apply_softmax: If ``True`` (default), a softmax operator is
            automatically appended to the model, in order to get probability
            tensors as inference outputs from networks that don't already
            apply softmax.
        apply_argmax: If ``True``, the argmax of the model output is computed
            and returned instead of the class score tensor. This can be used
            for classification if you are only interested in the final argmax
            classification. This option can speed up predictions.
            Note that since argmax is not influenced by softmax,
            ``apply_softmax`` can be safely disabled if ``apply_argmax`` is
            ``True``, even if the model was trained with a softmax loss.
        argmax_with_threshold: If not ``None``, argmax output is enabled (see
            ``apply_argmax``) and all values at or below this threshold are
            set to 0 before the argmax is taken, so a non-background class is
            only assigned where its score exceeds the threshold.
        transform: Transformation function to be applied to inputs before
            performing inference. The primary use of this is for
            normalization. Make sure to use the same normalization parameters
            for inference as the ones that were used for training of the
            ``model``. See :py:mod:`elektronn3.data.transforms` for some
            implementations. For pure input normalization you can use this
            template::

            >>> from elektronn3.data import transforms
            >>> # m, s are mean, std of the inputs the model was trained on
            >>> transform = transforms.Normalize(mean=m, std=s)
        augmentations: List of test-time augmentations or an integer that
            specifies the number of different flips to be performed as
            test-time augmentations.
        strict_shapes: If ``False`` (default), force the ``out_shape`` to be
            a multiple of the ``tile_shape`` by padding the input. This
            allows for greater flexibility of the ``tile_shape`` but
            potentially wastes more computation (the padded region will be
            passed into the model but will later be discarded from the output
            tensor).
            If ``True``, incompatible shapes will result in an error.
        verbose: If ``True``, report inference speed.
        report_inp_stats: If ``True``, print the mean and standard deviation
            of the input before prediction.

    Examples:
        >>> model = nn.Sequential(
        ...     nn.Conv2d(5, 32, 3, padding=1), nn.ReLU(),
        ...     nn.Conv2d(32, 2, 1))
        >>> inp = np.random.randn(2, 5, 10, 10)
        >>> predictor = Predictor(model)
        >>> out = predictor.predict(inp)
        >>> assert np.all(np.array(out.shape) == np.array([2, 2, 10, 10]))
    """
    def __init__(
            self,
            model: Union[nn.Module, str, Path],
            state_dict_src: Optional[Union[str, dict]] = None,
            device: Optional[Union[torch.device, str]] = None,
            batch_size: Optional[int] = None,
            tile_shape: Optional[Tuple[int, ...]] = None,
            overlap_shape: Optional[Tuple[int, ...]] = None,
            offset: Optional[Tuple[int, ...]] = None,
            out_shape: Optional[Tuple[int, ...]] = None,
            out_dtype: Optional[torch.dtype] = None,
            float16: bool = False,
            apply_softmax: bool = True,
            transform: Optional[Transform] = None,
            augmentations: Union[int, Optional[Sequence]] = None,
            strict_shapes: bool = False,
            apply_argmax: bool = False,
            argmax_with_threshold: Optional[float] = None,
            verbose: bool = False,
            report_inp_stats: bool = False
    ):
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            logger.info(f'Running on device {device}')
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device
        self.batch_size = batch_size
        self.out_dtype = out_dtype
        self.float16 = float16
        if isinstance(model, Path):
            model = str(model)
        if float16 and not isinstance(model, str) and next(model.parameters()).dtype != torch.float16:
            # If the model is passed as an object and not already in float16,
            # we need to deepcopy it because model casting to float16 is only
            # supported in-place - otherwise we would downcast it irreversibly
            # in the calling scope as well.
            model = copy.deepcopy(model)
        self.dtype = torch.float16 if float16 else torch.float32
        self.transform = transform
        if isinstance(augmentations, int):
            augmentations = DEFAULT_AUGMENTATIONS_3D[:augmentations]
        self.augmentations = augmentations
        self.strict_shapes = strict_shapes
        self.apply_argmax = apply_argmax
        self.argmax_with_threshold = argmax_with_threshold
        self.verbose = verbose
        self.report_inp_stats = report_inp_stats
        if isinstance(model, str):
            if os.path.isfile(model):
                if model.endswith('.pts'):
                    model = torch.jit.load(model, map_location=device)
                elif model.endswith('.pt'):
                    model = torch.load(model, map_location=device)
                else:
                    raise ValueError(f'{model} has an unknown file extension. Supported are .pt and .pts')
            else:
                raise ValueError(f'Model path {model} not found.')
        self.model = model
        if isinstance(state_dict_src, str):
            state_dict = torch.load(state_dict_src)
            if 'model_state_dict' in state_dict:  # Handle nested dicts
                state_dict = state_dict['model_state_dict']
        elif isinstance(state_dict_src, dict) or state_dict_src is None:
            state_dict = state_dict_src
        else:
            raise ValueError(
                '"state_dict_src" has to be either a path to a .pth file (str),'
                ' a state_dict object (dict) or None.')
        if state_dict is not None:
            set_state_dict(model, state_dict)
        if not apply_softmax and augmentations is not None:
            raise ValueError('When augmentations are enabled, apply_softmax cannot be False.')
        if apply_softmax or augmentations is not None:
            self.model = nn.Sequential(self.model, nn.Softmax(1))
        if float16:
            self.model.half()  # This is destructive. float32 params are lost!
        self.apply_argmax_after_tta = False
        if apply_argmax or argmax_with_threshold is not None:
            # If augmentations are enabled, argmax is applied after the
            # test-time augmentations instead, see _predict()
            self.apply_argmax_after_tta = augmentations is not None
            if not self.apply_argmax_after_tta:
                argmax_layers = [Argmax(dim=1, unsqueeze=True)]
                if argmax_with_threshold:
                    argmax_layers = [nn.Threshold(argmax_with_threshold, 0)] + argmax_layers
                self.model = nn.Sequential(self.model, *argmax_layers)
            if self.out_dtype is None:
                self.out_dtype = torch.uint8
        self._warn_about_shapes = True
        self.model.eval()

        def is_set(array: Optional[Tuple[int, ...]]) -> bool:
            return array is not None and np.any(array)

        if is_set(overlap_shape) and is_set(offset):
            raise ValueError(
                f'overlap_shape={overlap_shape} and offset={offset} are both specified, but this is not supported.\n'
                'Either specify overlap_shape (if the spatial shapes of inputs and outputs are the same)\n'
                'or offset (if the output is smaller).'
            )
        if not is_set(tile_shape):  # no tiling
            assert not (is_set(out_shape) or is_set(overlap_shape) or is_set(offset)), \
                'If tile_shape is not set, out_shape, overlap_shape and offset should not be set either.'
            self.enable_tiling = False
        else:
            assert is_set(out_shape), 'If tile_shape is set, out_shape is required to be set, too.'
            self.enable_tiling = True
            if offset is None:
                logger.warning(
                    'Predictor: offset=None -> Estimating offset from a forward pass. This can fail or '
                    'silently lead to incorrect results and costs an additional forward pass. '
                    'To avoid this, please set the Predictor offset explicitly.'
                )
                offset = utils.calculate_offset(self.model)
            if np.count_nonzero(offset) == 0:  # no valid conv → disable offset
                offset = None
            else:
                offset = np.array(offset)
                # Set overlap to offset shape because IMO that's the only reasonable choice.
                overlap_shape = offset
                out_shape = np.array([*out_shape[:-len(offset)], *(out_shape[-len(offset):] - 2 * offset)])
                logger.info(f'Adjusted out_shape: {out_shape}')
        self.offset = offset
        self.overlap_shape = np.array(overlap_shape) if overlap_shape is not None else None
        self.tile_shape = np.array(tile_shape) if tile_shape is not None else None
        self.out_shape = np.array(out_shape) if out_shape is not None else None

    @torch.no_grad()
    def _predict(self, dinp: torch.Tensor, crop_slice=None) -> torch.Tensor:
        dinp = dinp.to(self.device, dtype=self.dtype)
        dout = self.model(dinp)
        if crop_slice is not None:
            dout = dout[crop_slice]
        # If enabled, apply test-time augmentations and take the mean value.
        # Augmentations are applied directly on the compute device and
        # intermediate results are stored on-device, so this can increase
        # GPU memory usage!
        if self.augmentations is not None:
            douts = [dout]
            for aug in self.augmentations:
                dinp_aug = aug.forward(dinp)
                dout_aug = self.model(dinp_aug.to(self.device))
                dout = aug.backward(dout_aug)
                if crop_slice is not None:
                    dout = dout[crop_slice]
                douts.append(dout)
            douts = torch.stack(douts)
            dout = torch.mean(douts, dim=0)
        # If no augmentations are used, argmax was already applied by the model
        if self.apply_argmax_after_tta:
            if self.argmax_with_threshold:
                dout[dout <= self.argmax_with_threshold] = 0
            dout = dout.argmax(dim=1).to(self.out_dtype)
        dout = dout.to(self.out_dtype)
        return dout

    def _tiled_predict(
            self,
            inp: torch.Tensor,
            out_shape: Optional[Tuple[int]] = None
    ) -> torch.Tensor:
        """Tiled inference with overlapping input tiles.

        Tiling is not used if ``tile_shape`` and ``overlap_shape`` are
        undefined."""
        if self.enable_tiling:
            if self.out_shape is None:
                raise ValueError('If you use tiling, you also need to supply out_shape.')
            out_shape = (inp.shape[0], *out_shape)
            return tiled_apply(
                self._predict,
                inp=inp,
                tile_shape=self.tile_shape,
                overlap_shape=self.overlap_shape,
                offset=self.offset,
                out_shape=out_shape,
                verbose=self.verbose
            )
        # Otherwise: No tiling, apply model to the whole input in one step
        return self._predict(inp)

    def _splitbatch_predict(
            self,
            inp: torch.Tensor,
            num_batches: int,
            out_shape: Optional[Tuple[int]] = None
    ) -> torch.Tensor:
        """Split the input batch into smaller batches of the specified
        ``batch_size`` and perform inference on each of them separately."""
        if self.out_shape is None:
            raise ValueError('If you define a batch_size, you also need to supply out_shape.')
        out = torch.empty((inp.shape[0], *self.out_shape), dtype=self.dtype)
        for k in range(0, num_batches):
            low = self.batch_size * k
            high = self.batch_size * (k + 1)
            out[low:high] = self._tiled_predict(inp[low:high], out_shape=out_shape)
        return out

    def predict(
            self,
            inp: Union[np.ndarray, torch.Tensor],
    ) -> torch.Tensor:
        """Perform prediction on ``inp`` and return the prediction.

        Args:
            inp: Input data, e.g. of shape (N, C, H, W).
                Can be an ``np.ndarray`` or a ``torch.Tensor``.
                Note that ``inp`` is automatically converted to the specified
                ``dtype`` (default: ``torch.float32``) before inference.

        Returns:
            Model output
        """
        if self.report_inp_stats:
            from elektronn3.data import utils
            try:
                print('input dist', utils.calculate_means(inp.numpy()), utils.calculate_stds(inp.numpy()))
            except AttributeError:  # inp is already a numpy array (no .numpy() method)
                print('input dist', utils.calculate_means(inp), utils.calculate_stds(inp))
        if self.transform is not None:
            if isinstance(inp, torch.Tensor):
                inp = inp.numpy()  # transforms currently only work with numpy ndarrays as in/output
            transformed = np.empty_like(inp)
            for i in range(inp.shape[0]):  # Apply transform for each sample of the batch separately
                transformed[i], _ = self.transform(inp[i], None)  # target=None because we don't have any here
            inp = transformed
        if self.verbose:
            start = time.time()
        # Check/change out_shape for divisibility by tile_shape
        if self.enable_tiling:
            inp, out_shape, relevant_slice = self._ensure_matching_shapes(inp)
        else:
            relevant_slice = None
            out_shape = self.out_shape
        inp = torch.as_tensor(inp, dtype=self.dtype).contiguous()
        inp_batch_size = inp.shape[0]
        spatial_shape = np.array(inp.shape[2:])
        # Lazily figure out these Predictor options based on the input it
        # receives if they are not already set.
        # Not sure if that's a good idea because these are object-changing
        # side effects in the otherwise pure predict() function.
        if self.out_dtype is None:
            self.out_dtype = torch.uint8 if self.argmax_with_threshold is not None else inp.dtype
        if out_shape is not None and out_shape[0] > 255 and self.out_dtype == torch.uint8:
            raise ValueError(
                f'C = out_shape[0] = {out_shape[0]}, but out_dtype torch.uint8 '
                'can only hold values up to 255.'
            )
        if self.tile_shape is None:
            self.tile_shape = spatial_shape
        if self.overlap_shape is None:
            self.overlap_shape = np.zeros_like(spatial_shape)
        if self.batch_size is None:
            self.batch_size = inp_batch_size
        num_batches = int(np.ceil(inp_batch_size / self.batch_size))
        if num_batches == 1:  # Predict everything in one step
            out = self._tiled_predict(inp=inp, out_shape=out_shape)
        else:  # Split input batch into smaller batches and predict separately
            out = self._splitbatch_predict(inp=inp, num_batches=num_batches, out_shape=out_shape)
        # Explicit synchronization so the next GPU operation won't be
        # mysteriously slow. If we don't synchronize manually here, profilers
        # will misleadingly report a huge amount of time spent in out.cpu().
        if self.device.type == 'cuda':
            torch.cuda.synchronize()
        out = out.cpu() if relevant_slice is None else out[relevant_slice].cpu()
        if self.verbose:
            dtime = time.time() - start
            amount = out.numel()
            if out_shape is not None and np.array_equal(out_shape[2:], inp.shape[2:]):
                # Calculate the amount of valid data produced
                amount = np.prod([*out.shape[:-3], *(out.shape[-3:] - 2 * self.overlap_shape)])
            speed = amount / dtime / 1e6
            print(f'Inference speed: {speed:.2f} MVox/s, time: {dtime:.2f} s.')
        return out
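
    # Hypothetical tiled-inference setup (all shapes are only illustrative
    # and depend on the model and on available GPU memory):
    # >>> predictor = Predictor(
    # ...     model, tile_shape=(64, 64, 64), overlap_shape=(16, 16, 16),
    # ...     out_shape=(2, 128, 128, 128), apply_softmax=True)
    # >>> out = predictor.predict(np.random.randn(1, 1, 128, 128, 128))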

    # TODO: Make this work with input shape != output shape
    def _ensure_matching_shapes(
            self, inp: np.ndarray
    ) -> Tuple[np.ndarray, Optional[Tuple[int]], Optional[slice]]:
        if self.out_shape is not None and np.any(self.out_shape[1:] % self.tile_shape):
            if self.strict_shapes:
                raise ValueError(
                    'Make sure that out_shape is divisible by tile_shape or '
                    'relax this constraint by setting strict_shapes=False.'
                )
            else:
                padded_out_shape = np.array(self.out_shape)
                padded_out_shape[1:] = np.ceil(self.out_shape[1:] / self.tile_shape) * self.tile_shape
                if self.offset is None:
                    offset = np.zeros(shape=len(padded_out_shape) - 1, dtype=np.int64)
                else:
                    offset = np.array(self.offset)
                padded_inp_shape = (*inp.shape[:2], *(padded_out_shape[1:] + 2 * offset))
                padded_inp = np.zeros(padded_inp_shape)
                # Define the relevant region (that is: without the padding that was just added)
                relevant_slice_inp = _extend_nc([slice(0, d) for d in inp.shape[2:]])
                relevant_slice_out = _extend_nc([slice(0, d) for d in self.out_shape[1:]])
                padded_inp[relevant_slice_inp] = inp
                if self._warn_about_shapes and np.any(padded_out_shape != self.out_shape):
                    sh_diff = np.subtract(padded_out_shape, self.out_shape)
                    # Only nonzero elements are multiplied, otherwise it would be 0.
                    wasted_pix = np.prod(sh_diff[sh_diff != 0])
                    total_pix = np.prod(padded_out_shape)
                    wasted_percentage = 100 * wasted_pix / total_pix
                    logger.info(
                        f'Adapting out_shape {tuple(self.out_shape[1:])} to '
                        f'tile_shape {tuple(self.tile_shape)} '
                        f'by padding out_shape to {tuple(padded_out_shape[1:])}.\n'
                        f'Suboptimal shapes will reduce execution speed.'
                        # f'At least {wasted_percentage:.2f}% of total compute will be '
                        # f'wasted by this padding.'
                    )
                    self._warn_about_shapes = False
                    # TODO: Calculate exact compute waste by looking at increased tile overlaps
                    #  (the current estimation omits the (potentially high-impact) added per-tile
                    #  padding/overlaps via overlap_shape).
        else:
            padded_inp = inp
            padded_out_shape = self.out_shape
            relevant_slice_out = None
        return padded_inp, padded_out_shape, relevant_slice_out

    def predict_proba(self, inp):
        logger.warning(
            'Predictor.predict_proba(inp) is deprecated. '
            'Please use Predictor.predict(inp) instead.'
        )
        return self.predict(inp)


# TODO: This can be replaced with a single model.load_state_dict(state_dict) call
#  after a while, because Trainer._save_model() now always saves unwrapped
#  modules if a parallel wrapper is detected. Or should we still keep this
#  for better support of models accidentally saved in wrapped state?
def set_state_dict(model: torch.nn.Module, state_dict: dict):
    """Set the state dict of a model.

    Also works with ``torch.nn.DataParallel`` models."""
    try:
        model.load_state_dict(state_dict)
    # If the model was saved as nn.DataParallel, remove the 'module.' prefix
    # in every key
    except RuntimeError:  # TODO: Is it safe to catch all runtime errors here?
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k.replace('module.', '')] = v
        model.load_state_dict(new_state_dict)
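
# Example of the fallback path: a checkpoint saved from an nn.DataParallel-
# wrapped model stores keys like 'module.weight'; stripping the 'module.'
# prefix makes it loadable into the unwrapped model:
# >>> wrapped_sd = {'module.' + k: v for k, v in nn.Conv2d(1, 1, 1).state_dict().items()}
# >>> set_state_dict(nn.Conv2d(1, 1, 1), wrapped_sd)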