elektronn3.data.cnndata module

class elektronn3.data.cnndata.PatchCreator(input_sources, patch_shape, target_sources=None, offset=(0, 0, 0), cube_prios=None, aniso_factor=2, target_discrete_ix=None, input_discrete_ix=None, target_dtype=<class 'numpy.int64'>, train=True, warp_prob=False, warp_kwargs=None, epoch_size=100, transform=<elektronn3.data.transforms.transforms.Identity object>, in_memory=False, cube_meta=<elektronn3.data.cnndata._DefaultCubeMeta object>)

Bases: torch.utils.data.

Dataset iterator class that creates 3D image patches from HDF5 files.

It implements the PyTorch Dataset interface and is meant to be used with a PyTorch DataLoader (or the modified elektronn3.training.trainer.train_utils.DelayedDataLoader, if it is used with elektronn3.training.trainer.Trainer).

The main idea of this class is to automate input and target patch creation for training convnets for semantic image segmentation. Patches are sliced from random locations in the supplied HDF5 files (input_sources, target_sources). Optionally, the source coordinates from which patches are sliced are obtained by random warping with affine or perspective transformations for efficient augmentation and avoiding border artifacts (see warp_prob, warp_kwargs). Note that whereas other warping-based image augmentation systems usually warp images themselves, elektronn3 performs warping transformations on the coordinates from which image patches are sliced and obtains voxel values by interpolating between actual image voxels at the warped source locations (which are not confined to the original image’s discrete coordinate grid). (TODO: A visualization would be very helpful here to make this more clear) For more information about this warping mechanism see elektronn3.data.coord_transforms.get_warped_slice().
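As a rough illustration of warping the sampling coordinates instead of the image, the following stdlib-only 2D sketch rotates the source coordinates of a patch around its center and reads values by bilinear interpolation at those off-grid locations. It is a simplified stand-in, not elektronn3's get_warped_slice(): the helper names bilinear and warped_patch are hypothetical, only a rotation is shown (no perspective transform, no 3D), and out-of-bounds coordinates are simply clamped to the border.

```python
import math
from typing import List, Tuple

Image = List[List[float]]


def bilinear(img: Image, y: float, x: float) -> float:
    """Sample img at a real-valued (y, x) by bilinear interpolation.

    Coordinates outside the image are clamped to the border (a
    simplification made for this sketch only).
    """
    h, w = len(img), len(img[0])
    y = min(max(y, 0.0), h - 1.0)
    x = min(max(x, 0.0), w - 1.0)
    y0, x0 = math.floor(y), math.floor(x)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    dy, dx = y - y0, x - x0
    top = img[y0][x0] * (1 - dx) + img[y0][x1] * dx
    bottom = img[y1][x0] * (1 - dx) + img[y1][x1] * dx
    return top * (1 - dy) + bottom * dy


def warped_patch(img: Image, center: Tuple[float, float],
                 patch_shape: Tuple[int, int], angle: float) -> Image:
    """Slice a patch whose *source coordinates* are rotated by angle
    (radians) around center. The image itself is never transformed."""
    ph, pw = patch_shape
    cy, cx = center
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    patch = []
    for i in range(ph):
        row = []
        for j in range(pw):
            # Offset of this output pixel from the patch center
            oy, ox = i - (ph - 1) / 2, j - (pw - 1) / 2
            # Rotate the offset to get the (off-grid) source coordinate
            sy = cy + cos_a * oy - sin_a * ox
            sx = cx + sin_a * oy + cos_a * ox
            row.append(bilinear(img, sy, sx))
        patch.append(row)
    return patch
```

With angle=0.0 this reduces to plain slicing; any other angle yields values interpolated between neighboring pixels, which is the property the paragraph above describes.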

Currently, only 3-dimensional image data sets are supported, but 2D support is also planned.

Parameters

  • input_sources (List[Tuple[str, str]]) – Sequence of (filename, hdf5_key) tuples, where each item specifies the filename and the HDF5 dataset key under which the input data is stored.

  • target_sources (Optional[List[Tuple[str, str]]]) – Sequence of (filename, hdf5_key) tuples, where each item specifies the filename and the HDF5 dataset key under which the target data is stored.

  • patch_shape (Sequence[int]) – Desired spatial shape of the samples that the iterator delivers by slicing from the data set files. Since this determines the size of input samples that are fed into the neural network, this is a very important value to tune. Making it too large can result in slow training and excessive memory consumption, but if it is too small, it can hinder the perceptive ability of the neural network because the samples it “sees” get too small to extract meaningful features. Adequate values for patch_shape are highly dependent on the data set (“How large are typical ROIs? How large does an image patch need to be so you can understand the input?”) and also depend on the neural network architecture to be used (If the effective receptive field of the network is small, larger patch sizes won’t help much).

  • offset (Sequence[int]) – Shape of the offset by which the targets are cropped on each side. This needs to be set if the outputs of the network you train with are smaller than its inputs. For example, if the spatial shape of your inputs is patch_shape=(48, 96, 96) and the spatial shape of your outputs is out_shape=(32, 56, 56), you should set offset=(8, 20, 20), because offset = (patch_shape - out_shape) / 2 should always hold true.

  • cube_prios (Optional[Sequence[float]]) – List of per-cube priorities, where a higher priority means that it is more likely that a sample comes from this cube.

  • aniso_factor (int) – Depth-anisotropy factor of the data set. E.g. if your data set has half resolution in the depth dimension, set aniso_factor=2. If all dimensions have the same resolution, set aniso_factor=1.

  • input_discrete_ix (Optional[List[int]]) – List of input channels that contain discrete values. By default (None), no channel is seen as discrete (generally inputs are real-world images). This information is used to decide which kind of interpolation is used for reading data: discrete channels are read with nearest-neighbor interpolation, and non-discrete (continuous) channels are linearly interpolated.

  • target_discrete_ix (Optional[List[int]]) – List of target channels that contain discrete values. By default (None), every channel is seen as discrete (this is generally the case for classification tasks). See input_discrete_ix for the effect on target interpolation.

  • target_dtype (dtype) – dtype that target tensors should be cast to.

  • train (bool) – Determines if samples come from training or validation data. If True, training data is returned. If False, validation data is returned.

  • warp_prob (Union[bool, float]) – Ratio of training samples that should be obtained using geometric warping augmentations. False/True disable/enable warping completely.

  • warp_kwargs (Optional[Dict[str, Any]]) – kwargs that are passed through to elektronn3.data.coord_transforms.get_warped_slice(). See the docs of this function for information on kwargs options. Can be empty.

  • epoch_size (int) – Determines the length (__len__) of the Dataset iterator. epoch_size can be set to an arbitrary value and doesn’t have any effect on the content of produced training samples. It is recommended to set it to a suitable value for one “training phase”, so after each epoch_size batches, validation/logging/plotting are performed by the training loop that uses this data set (e.g. elektronn3.training.trainer.Trainer).

  • transform (Callable) – Transformation function to be applied to (inp, target) samples (for normalization, data augmentation etc.). The signature is always inp, target = transform(inp, target), where inp and target both are numpy.ndarrays. In some transforms, target can also be set to None; in this case it is ignored and only inp is processed. To combine multiple transforms, use elektronn3.data.transforms.Compose. See elektronn3.data.transforms for some implementations.

  • in_memory (bool) – If True, all data set files are immediately loaded into host memory and are permanently kept there as numpy arrays. If this is disabled (default), file contents are always read from the HDF5 files to produce samples. (Note: This does not mean it’s slower, because file contents are transparently cached by h5py, see http://docs.h5py.org/en/latest/high/file.html#chunk-cache).
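The offset rule, the role of cube_prios and the transform signature described in the parameter list can be illustrated with a few stdlib-only helpers. These are hypothetical sketches, not part of elektronn3: compute_offset, pick_cube, compose, scale_input and flip are illustrative names, and pick_cube assumes that priorities act as relative sampling weights (the documentation only states that a higher priority makes a cube more likely to be sampled).

```python
import random
from typing import Any, Callable, Optional, Sequence, Tuple


def compute_offset(patch_shape: Sequence[int],
                   out_shape: Sequence[int]) -> Tuple[int, ...]:
    """Per-dimension offset = (patch_shape - out_shape) / 2."""
    offset = []
    for p, o in zip(patch_shape, out_shape):
        diff = p - o
        # A symmetric crop needs a non-negative, even size difference
        if diff < 0 or diff % 2 != 0:
            raise ValueError(f"cannot crop {p} to {o} symmetrically")
        offset.append(diff // 2)
    return tuple(offset)


def pick_cube(cube_prios: Optional[Sequence[float]], n_cubes: int,
              rng: random.Random) -> int:
    """Pick a cube index. Assumption: priorities are relative weights;
    None means all cubes are equally likely."""
    weights = list(cube_prios) if cube_prios is not None else [1.0] * n_cubes
    return rng.choices(range(n_cubes), weights=weights, k=1)[0]


# Every transform maps (inp, target) -> (inp, target); target may be None.
Transform = Callable[[Any, Any], Tuple[Any, Any]]


def compose(*transforms: Transform) -> Transform:
    """Chain transforms left to right (similar in spirit to Compose)."""
    def apply(inp: Any, target: Any) -> Tuple[Any, Any]:
        for t in transforms:
            inp, target = t(inp, target)
        return inp, target
    return apply


def scale_input(inp: Sequence[float], target: Any) -> Tuple[list, Any]:
    # Normalization touches only the input; target passes through unchanged
    return [x / 255.0 for x in inp], target


def flip(inp: Sequence[float], target: Any) -> Tuple[list, Any]:
    # Geometric augmentation must be applied to inp and target alike
    return list(inp)[::-1], (None if target is None else list(target)[::-1])
```

For the example in the offset description, compute_offset((48, 96, 96), (32, 56, 56)) gives (8, 20, 20). Raising on odd differences makes an impossible symmetric crop fail early instead of silently misaligning inputs and targets.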


Check if all files are accessible.
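A file accessibility check of this kind can be sketched with the standard library alone. The helper names find_inaccessible and check_files are hypothetical (the name of the documented method is not shown here), and the sketch only checks that each file exists and is readable; it does not verify that the HDF5 dataset key exists inside the file, which would require opening it, e.g. with h5py.

```python
import os
from typing import Iterable, List, Tuple


def find_inaccessible(sources: Iterable[Tuple[str, str]]) -> List[str]:
    """Return the filenames among (filename, hdf5_key) sources that are
    not existing, readable regular files."""
    return [
        filename
        for filename, _hdf5_key in sources
        if not (os.path.isfile(filename) and os.access(filename, os.R_OK))
    ]


def check_files(sources: Iterable[Tuple[str, str]]) -> None:
    """Raise FileNotFoundError listing every inaccessible file at once."""
    missing = find_inaccessible(sources)
    if missing:
        raise FileNotFoundError(f"Inaccessible data files: {missing}")
```

Calling such a check once on all input and target sources before training starts fails fast and reports every bad path together, rather than crashing on the first missing file mid-epoch.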

Return type

Tuple[List[DataSource], Optional[List[DataSource]]]

warp_cut(inp_src, target_src, warp_prob, warp_kwargs)

(Wraps elektronn3.data.coord_transforms.get_warped_slice())

Cuts a warped slice out of the input and target arrays. The same random warping transformation is applied to both input and target.

Warping is randomly applied with the probability defined by the warp_prob parameter (see below).

Parameters

  • inp_src (h5py.Dataset) – Input image source (in HDF5)

  • target_src (h5py.Dataset) – Target image source (in HDF5)

  • warp_prob (float or bool) – False/True disable/enable warping completely. If warp_prob is a float, it is used as the ratio of inputs that should be warped. E.g. 0.5 means approx. every second call to this function actually applies warping to the image-target pair.

  • warp_kwargs (dict) – kwargs that are passed through to elektronn3.data.coord_transforms.get_warped_slice(). Can be empty.

Return type

Tuple[ndarray, Optional[ndarray]]


Returns

  • inp (np.ndarray) – (Warped) input image slice

  • target (Optional[np.ndarray]) – (Warped) target slice, or None if there is no target source
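The documented semantics of warp_prob (False/True disable/enable warping, a float is the ratio of warped samples) and the rule that input and target receive the same random transformation can be sketched as follows. decide_warp and warp_pair are hypothetical names, and the circular shift in warp_pair is only a stand-in for a real geometric warp; the actual warping is done by elektronn3.data.coord_transforms.get_warped_slice().

```python
import random
from typing import Optional, Sequence, Tuple, Union


def decide_warp(warp_prob: Union[bool, float], rng: random.Random) -> bool:
    """True if the current sample should be warped.

    False/True disable/enable warping completely; a float is the
    ratio of samples that should be warped.
    """
    if isinstance(warp_prob, bool):
        return warp_prob
    if not 0.0 <= warp_prob <= 1.0:
        raise ValueError("warp_prob must be a bool or a float in [0, 1]")
    return rng.random() < warp_prob


def warp_pair(inp: Sequence, target: Optional[Sequence],
              warp_prob: Union[bool, float],
              rng: random.Random) -> Tuple[list, Optional[list]]:
    """Apply one random circular shift (a stand-in for a real warp) to
    both inp and target, so that they stay spatially aligned."""
    if not decide_warp(warp_prob, rng):
        return list(inp), (None if target is None else list(target))
    k = rng.randrange(len(inp))  # drawn once, shared by inp and target

    def shift(seq: Sequence) -> list:
        return list(seq[k:]) + list(seq[:k])

    return shift(inp), (None if target is None else shift(target))
```

The bool check comes first because bool is a subclass of int in Python; with it, True and False are never interpreted as probabilities. Drawing the random parameter once per call is what keeps the input and target aligned.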

property warp_stats: str

class elektronn3.data.cnndata.Reconstruction2d(*args: Any, **kwargs: Any)

Bases: torch.utils.data.

Simple dataset for 2D reconstruction, e.g. for auto-encoders.
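For an auto-encoder, "reconstruction" means the network is trained to reproduce its own input, so each sample's target is the input itself. The following sketch assumes that this is the intent of Reconstruction2d, whose internals are not documented here; ReconstructionSketch is a hypothetical stdlib-only stand-in that only mimics the __len__/__getitem__ protocol of a map-style torch.utils.data.Dataset.

```python
import copy
from typing import Any, Callable, List, Optional, Sequence, Tuple


class ReconstructionSketch:
    """Hypothetical map-style dataset for auto-encoder training:
    the target of every sample is the (transformed) input itself."""

    def __init__(self, images: Sequence[Any],
                 transform: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None):
        self.images: List[Any] = list(images)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        inp = self.images[index]
        if self.transform is not None:
            # target=None: the transform only processes the input
            inp, _ = self.transform(inp, None)
        # Copy so in-place changes to one object cannot leak into the other
        return inp, copy.deepcopy(inp)
```

Applying the transform before duplicating the input ensures that the target matches what the network actually sees, including any normalization.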

class elektronn3.data.cnndata.Segmentation2d(inp_paths, target_paths, transform=<elektronn3.data.transforms.transforms.Identity object>, offset=(0, 0, 0), in_memory=True, inp_dtype=<class 'numpy.float32'>, target_dtype=<class 'numpy.int64'>, epoch_multiplier=1)

Bases: torch.utils.data.

Simple dataset for 2D segmentation.

Expects lists inp_paths and target_paths, where target_paths[i] is the target of inp_paths[i] for all i.
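The pairing contract between inp_paths and target_paths can be sketched as a small index-based dataset. PairedPathsSketch is hypothetical, and its handling of epoch_multiplier (repeating the list of pairs within one epoch) is an assumption based only on the parameter name, not on documented behavior of Segmentation2d.

```python
from typing import List, Sequence, Tuple


class PairedPathsSketch:
    """Hypothetical illustration of the pairing contract:
    target_paths[i] is the target of inp_paths[i]."""

    def __init__(self, inp_paths: Sequence[str], target_paths: Sequence[str],
                 epoch_multiplier: int = 1):
        if len(inp_paths) != len(target_paths):
            raise ValueError("inp_paths and target_paths must have equal length")
        self.pairs: List[Tuple[str, str]] = list(zip(inp_paths, target_paths))
        self.epoch_multiplier = epoch_multiplier

    def __len__(self) -> int:
        # Assumption: epoch_multiplier repeats the list of pairs per epoch
        return len(self.pairs) * self.epoch_multiplier

    def __getitem__(self, index: int) -> Tuple[str, str]:
        # Indices beyond len(self.pairs) wrap around to the same pairs
        return self.pairs[index % len(self.pairs)]
```

Validating the lengths up front turns a silently truncated zip() into an explicit error, which is the most common way such index-aligned path lists go wrong.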


class elektronn3.data.cnndata.SimpleNeuroData2d(inp_path=None, target_path=None, train=True, inp_key='raw', target_key='lab', pool=(1, 1, 1), transform=<elektronn3.data.transforms.transforms.Identity object>, out_channels=None)

Bases: torch.utils.data.

2D Dataset class for neuro_data_cdhw, reading from a single HDF5 file.

Delivers 2D image slices from the (H, W) plane at given D indices. Not scalable, keeps everything in memory. This is just a minimalistic proof of concept.
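Slicing 2D images from the (H, W) plane at a given depth index can be sketched on nested lists, assuming (as the name neuro_data_cdhw suggests) that the volume is stored in (C, D, H, W) axis order. For a NumPy array of that layout, the same slice is simply volume[:, d]. SliceSketch and hw_slice are hypothetical names, and treating each D index as one sample is an assumption.

```python
from typing import List

Image = List[List[float]]    # (H, W)
Volume = List[List[Image]]   # (C, D, H, W)


def hw_slice(volume: Volume, d: int) -> List[Image]:
    """Return the (C, H, W) image at depth index d of a (C, D, H, W) volume."""
    return [channel[d] for channel in volume]


class SliceSketch:
    """Hypothetical: one sample per depth index, everything kept in memory."""

    def __init__(self, volume: Volume):
        self.volume = volume

    def __len__(self) -> int:
        return len(self.volume[0])  # number of D slices

    def __getitem__(self, d: int) -> List[Image]:
        return hw_slice(self.volume, d)
```

Keeping the full volume in memory makes indexing trivial, which matches the "not scalable, proof of concept" character described above.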