Source code for elektronn3.data.cnndata

# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Martin Drawitsch, Philipp Schubert

__all__ = ['PatchCreator', 'SimpleNeuroData2d', 'Segmentation2d', 'Reconstruction2d']

import glob
import logging
import os
import sys
import traceback
from os.path import expanduser
from typing import Tuple, Dict, Optional, Union, Sequence, Any, List, Callable

import h5py
import imageio
import numpy as np
import torch
from torch.utils import data

from elektronn3.data import coord_transforms, transforms
from elektronn3.data.sources import DataSource, HDF5DataSource, slice_3d

logger = logging.getLogger('elektronn3log')


class _DefaultCubeMeta:
    def __getitem__(self, *args, **kwargs): return np.inf


# TODO: Document passing DataSources directly
[docs] class PatchCreator(data.Dataset): """Dataset iterator class that creates 3D image patches from HDF5 files. It implements the PyTorch ``Dataset`` interface and is meant to be used with a PyTorch ``DataLoader`` (or the modified :py:class:`elektronn3.training.trainer.train_utils.DelayedDataLoader`, if it is used with :py:class:`elektronn3.training.trainer.Trainer``). The main idea of this class is to automate input and target patch creation for training convnets for semantic image segmentation. Patches are sliced from random locations in the supplied HDF5 files (``input_h5data``, ``target_h5data``). Optionally, the source coordinates from which patches are sliced are obtained by random warping with affine or perspective transformations for efficient augmentation and avoiding border artifacts (see ``warp_prob``, ``warp_kwargs``). Note that whereas other warping-based image augmentation systems usually warp images themselves, elektronn3 performs warping transformations on the **coordinates** from which image patches are sliced and obtains voxel values by interpolating between actual image voxels at the warped source locations (which are not confined to the original image's discrete coordinate grid). (TODO: A visualization would be very helpful here to make this more clear) For more information about this warping mechanism see :py:meth:`elektronn3.data.cnndata.warp_slice()`. Currently, only 3-dimensional image data sets are supported, but 2D support is also planned. Args: input_sources: Sequence of ``(filename, hdf5_key)`` tuples, where each item specifies the filename and the HDF5 dataset key under which the input data is stored. target_sources: Sequence of ``(filename, hdf5_key)`` tuples, where each item specifies the filename and the HDF5 dataset key under which the target data is stored. patch_shape: Desired spatial shape of the samples that the iterator delivers by slicing from the data set files. Since this determines the size of input samples that are fed into the neural network, this is a very important value to tune. Making it too large can result in slow training and excessive memory consumption, but if it is too small, it can hinder the perceptive ability of the neural network because the samples it "sees" get too small to extract meaningful features. Adequate values for ``patch_shape`` are highly dependent on the data set ("How large are typical ROIs? How large does an image patch need to be so you can understand the input?") and also depend on the neural network architecture to be used (If the effective receptive field of the network is small, larger patch sizes won't help much). offset: Shape of the offset by which each the targets are cropped on each side. This needs to be set if the outputs of the network you train with are smaller than its inputs. For example, if the spatial shape of your inputs is ``patch_shape=(48, 96, 96)`` the spatial shape of your outputs is ``out_shape=(32, 56, 56)``, you should set ``offset=(8, 20, 20)``, because ``offset = (patch_shape - out_shape) / 2`` should always hold true. cube_prios: List of per-cube priorities, where a higher priority means that it is more likely that a sample comes from this cube. aniso_factor: Depth-anisotropy factor of the data set. E.g. if your data set has half resolution in the depth dimension, set ``aniso_factor=2``. If all dimensions have the same resolution, set ``aniso_factor=1``. input_discrete_ix: List of input channels that contain discrete values. By default (``None``), no channel is seen as discrete (generally inputs are real world images). This information is used to decide what kind of interpolation should be used for reading input data: - discrete targets are obtained by nearest-neighbor interpolation - non-discrete (continuous) targets are linearly interpolated. target_discrete_ix: List of target channels that contain discrete values. By default (``None``), every channel is seen as discrete (this is generally the case for classification tasks). See input_discrete_ix for the effect on target interpolation. target_dtype: dtype that target tensors should be cast to. train: Determines if samples come from training or validation data. If ``True``, training data is returned. If ``False``, validation data is returned. warp_prob: ratio of training samples that should be obtained using geometric warping augmentations. warp_kwargs: kwargs that are passed through to :py:meth:`elektronn3.data.coord_transforms.get_warped_slice()`. See the docs of this function for information on kwargs options. Can be empty. epoch_size: Determines the length (``__len__``) of the ``Dataset`` iterator. ``epoch_size`` can be set to an arbitrary value and doesn't have any effect on the content of produced training samples. It is recommended to set it to a suitable value for one "training phase", so after each ``epoch_size`` batches, validation/logging/plotting are performed by the training loop that uses this data set (e.g. ``elektronn3.training.trainer.Trainer``). transform: Transformation function to be applied to ``(inp, target)`` samples (for normalization, data augmentation etc.). The signature is always ``inp, target = transform(inp, target)``, where ``inp`` and ``target`` both are ``numpy.ndarray``s. In some transforms ``target`` can also be set to ``None``. In this case it is ignored and only ``inp`` is processed. To combine multiple transforms, use :py:class:`elektronn3.data.transforms.Compose`. See :py:mod:`elektronn3.data.transforms`. for some implementations. in_memory: If ``True``, all data set files are immediately loaded into host memory and are permanently kept there as numpy arrays. If this is disabled (default), file contents are always read from the HDF5 files to produce samples. (Note: This does not mean it's slower, because file contents are transparently cached by h5py, see http://docs.h5py.org/en/latest/high/file.html#chunk-cache). """ def __init__( self, input_sources: List[Tuple[str, str]], patch_shape: Sequence[int], target_sources: Optional[List[Tuple[str, str]]] = None, offset: Sequence[int] = (0, 0, 0), cube_prios: Optional[Sequence[float]] = None, aniso_factor: int = 2, target_discrete_ix: Optional[List[int]] = None, input_discrete_ix: Optional[List[int]] = None, target_dtype: np.dtype = np.int64, train: bool = True, warp_prob: Union[bool, float] = False, warp_kwargs: Optional[Dict[str, Any]] = None, epoch_size: int = 100, transform: Callable = transforms.Identity(), in_memory: bool = False, cube_meta=_DefaultCubeMeta(), ): # Early checks if target_sources is not None and len(input_sources) != len(target_sources): raise ValueError( 'If target_sources is not None, input_sources and ' 'target_sources must be lists of same length.' ) if not train: if warp_prob > 0: logger.warning('Augmentations should not be used on validation data.') # batch properties self.train = train self.warp_prob = warp_prob self.warp_kwargs = warp_kwargs if warp_kwargs is not None else {} # general properties self.input_sources = input_sources self.target_sources = target_sources self.cube_meta = cube_meta self.cube_prios = cube_prios self.aniso_factor = aniso_factor self.target_discrete_ix = target_discrete_ix self.input_discrete_ix = input_discrete_ix self.epoch_size = epoch_size self._orig_epoch_size = epoch_size # Store original epoch_size so it can be reset later. self.in_memory = in_memory self.patch_shape = np.array(patch_shape, dtype=np.int) self.ndim = self.patch_shape.ndim self.offset = np.array(offset) self.target_patch_shape = self.patch_shape - self.offset * 2 self._target_dtype = target_dtype self.transform = transform # Setup internal stuff self.pid = os.getpid() # The following fields will be filled when reading data self.n_labelled_pixels = 0 self.inputs: List[DataSource] = [] self.targets: List[DataSource] = [] self.load_data() # Open dataset files self.n_successful_warp = 0 self.n_failed_warp = 0 self._failed_warp_warned = False def __getitem__(self, index: int) -> Dict[str, Any]: # Note that the index is ignored. Samples are always random return self._get_random_sample() def _get_random_sample(self) -> Dict[str, Any]: input_src, target_src, i = self._getcube() # get cube randomly warp_prob = self.warp_prob while True: try: inp, target = self.warp_cut(input_src, target_src, warp_prob, self.warp_kwargs) if target is not None: target = target.astype(self._target_dtype) except coord_transforms.WarpingOOBError as e: # Temporarily set warp_prob to 1 to make sure that the next attempt # will also try to use warping. Otherwise, self.warp_prob would not # reflect the actual probability of a sample being obtained by warping. warp_prob = 1 if warp_prob > 0 else 0 self.n_failed_warp += 1 if self.n_failed_warp > 20 and self.n_failed_warp > 8 * self.n_successful_warp and not self._failed_warp_warned: fail_ratio = self.n_failed_warp / (self.n_failed_warp + self.n_successful_warp) fail_percentage = int(round(100 * fail_ratio)) print(e) logger.warning( f'{fail_percentage}% of warping attempts are failing.\n' 'Consider lowering lowering your input patch shapes or warp_kwargs[\'warp_amount\']).' ) self._failed_warp_warned = True continue except coord_transforms.WarpingSanityError: logger.exception('Invalid coordinate values encountered while warping. Retrying...') continue self.n_successful_warp += 1 try: inp, target = self.transform(inp, target) except transforms._DropSample: # A filter transform has chosen to drop this sample, so skip it logger.debug('Sample dropped.') continue break inp = torch.as_tensor(inp) cube_meta = torch.as_tensor(self.cube_meta[i]) fname = os.path.basename(self.inputs[i].fname) sample = { 'inp': inp, 'cube_meta': cube_meta, # TODO: Make cube_meta completely optional again 'fname': fname } if target is not None: sample['target'] = torch.as_tensor(target) return sample def __len__(self) -> int: return self.epoch_size # TODO: Write a good __repr__(). The version below is completely outdated. # def __repr__(self) -> str: # s = "{0:,d}-target Data Set with {1:,d} input channel(s):\n" + \ # "#train cubes: {2:,d} and #valid cubes: {3:,d}, {4:,d} labelled " + \ # "pixels." # s = s.format(self.c_target, self.c_input, self._training_count, # self._valid_count, self.n_labelled_pixels) # return s @property def warp_stats(self) -> str: return "Warp stats: successful: %i, failed %i, quota: %.1f" %( self.n_successful_warp, self.n_failed_warp, float(self.n_successful_warp)/(self.n_failed_warp+self.n_successful_warp))
[docs] def warp_cut( self, inp_src: DataSource, target_src: Optional[DataSource], warp_prob: Union[float, bool], warp_kwargs: Dict[str, Any] ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ (Wraps :py:meth:`elektronn3.data.coord_transforms.get_warped_slice()`) Cuts a warped slice out of the input and target arrays. The same random warping transformation is each applied to both input and target. Warping is randomly applied with the probability defined by the ``warp_prob`` parameter (see below). Parameters ---------- inp_src: h5py.Dataset Input image source (in HDF5) target_src: h5py.Dataset Target image source (in HDF5) warp_prob: float or bool False/True disable/enable warping completely. If ``warp_prob`` is a float, it is used as the ratio of inputs that should be warped. E.g. 0.5 means approx. every second call to this function actually applies warping to the image-target pair. warp_kwargs: dict kwargs that are passed through to :py:meth:`elektronn2.data.coord_transforms.get_warped_slice()`. Can be empty. Returns ------- inp: np.ndarray (Warped) input image slice target_src: np.ndarray (Warped) target slice """ if (warp_prob is True) or (warp_prob == 1): # always warp do_warp = True elif 0 < warp_prob < 1: # warp only a fraction of examples do_warp = True if (np.random.rand() < warp_prob) else False else: # never warp do_warp = False if not do_warp: warp_kwargs = dict(warp_kwargs) warp_kwargs['warp_amount'] = 0 if target_src is None: target_src_shape = None target_patch_shape = None else: target_src_shape = target_src.shape target_patch_shape = self.target_patch_shape M = coord_transforms.get_warped_coord_transform( inp_src_shape=inp_src.shape, patch_shape=self.patch_shape, aniso_factor=self.aniso_factor, target_src_shape=target_src_shape, target_patch_shape=target_patch_shape, **warp_kwargs ) inp, target = coord_transforms.warp_slice( inp_src=inp_src, patch_shape=self.patch_shape, M=M, target_src=target_src, target_patch_shape=target_patch_shape, target_discrete_ix=self.target_discrete_ix, input_discrete_ix=self.input_discrete_ix ) return inp, target
def _getcube(self) -> Tuple[DataSource, DataSource, int]: """ Draw an example cube according to sampling weight on training data, or randomly on valid data """ i = np.random.choice( np.arange(len(self.cube_prios)), p=self.cube_prios / np.sum(self.cube_prios) ) inp_source = self.inputs[i] target_source = None if self.targets is None else self.targets[i] return inp_source, target_source, i
[docs] def load_data(self) -> None: if len(self.inputs) == len(self.targets) == 0: inp_files, target_files = self.open_files() self.inputs.extend(inp_files) if target_files is None: self.targets = None else: self.targets.extend(target_files) else: logger.info('Using directly specified data sources.') if self.cube_prios is None: # If no priorities are given: sample proportionally to target sizes # if available, or else w.r.t. input sizes (voxel counts) self.cube_prios = [] if self.targets is None: self.cube_prios = [inp.size for inp in self.inputs] else: self.cube_prios = [target.size for target in self.targets] self.cube_prios = np.array(self.cube_prios, dtype=np.float32) / np.sum(self.cube_prios) logger.debug(f'cube_prios = {self.cube_prios}')
[docs] def check_files(self) -> None: """ Check if all files are accessible. """ notfound = False give_neuro_data_hint = False fullpaths = [f for f, _ in self.input_sources] if self.target_sources is not None: fullpaths.extend([f for f, _ in self.target_sources]) for p in fullpaths: if not os.path.exists(p): print('{} not found.'.format(p)) notfound = True if 'neuro_data_cdhw' in p: give_neuro_data_hint = True if give_neuro_data_hint: print('\nIt looks like you are referencing the neuro_data_cdhw dataset.\n' 'To install the neuro_data_xzy dataset to the default location, run:\n' ' $ wget https://github.com/ELEKTRONN/elektronn.github.io/releases/download/neuro_data_cdhw/neuro_data_cdhw.zip\n' ' $ unzip neuro_data_cdhw.zip -d ~/neuro_data_cdhw') if notfound: print('\nPlease fetch the necessary dataset and/or ' 'change the relevant file paths in the network config.') sys.stdout.flush() sys.exit(1)
[docs] def open_files(self) -> Tuple[List[DataSource], Optional[List[DataSource]]]: self.check_files() inp_sources, target_sources = [], [] modestr = 'Training' if self.train else 'Validation' memstr = ' (in memory)' if self.in_memory else '' logger.info(f'\n{modestr} data set{memstr}:') if self.target_sources is None: for (inp_fname, inp_key), cube_meta in zip(self.input_sources, self.cube_meta): inp_source = HDF5DataSource(fname=inp_fname, key=inp_key, in_memory=self.in_memory) logger.info(f' input: {inp_fname}[{inp_key}]: {inp_source.shape} ({inp_source.dtype})') if not np.all(cube_meta == np.inf): logger.info(f' cube_meta: {cube_meta}') inp_sources.append(inp_source) target_sources = None else: for (inp_fname, inp_key), (target_fname, target_key), cube_meta in zip(self.input_sources, self.target_sources, self.cube_meta): inp_source = HDF5DataSource(fname=inp_fname, key=inp_key, in_memory=self.in_memory) target_source = HDF5DataSource(fname=target_fname, key=target_key, in_memory=self.in_memory) logger.info(f' input: {inp_fname}[{inp_key}]: {inp_source.shape} ({inp_source.dtype})') logger.info(f' with target: {target_fname}[{target_key}]: {target_source.shape} ({target_source.dtype})') if not np.all(cube_meta == np.inf): logger.info(f' cube_meta: {cube_meta}') inp_sources.append(inp_source) target_sources.append(target_source) logger.info('') return inp_sources, target_sources
[docs] def set_offset(self, offset: Sequence[int]) -> None: self.offset = np.array(offset) self.target_patch_shape = self.patch_shape - self.offset * 2
def get_preview_batch( h5data: Tuple[str, str], preview_shape: Optional[Tuple[int, ...]] = None, transform: Optional[Callable] = None, in_memory: bool = False, dim: Optional[float] = None, ) -> torch.Tensor: fname, key = h5data inp_h5 = h5py.File(fname, 'r')[key] if in_memory: inp_h5 = inp_h5.value if dim is None: if preview_shape is None: raise ValueError('At least one of preview_shape, dim must be defined.') dim = len(preview_shape) # 2D or 3D inp_shape = np.array(inp_h5.shape[-dim:]) if preview_shape is None: # Slice everything inp_lo = np.zeros_like(inp_shape) inp_hi = inp_shape else: # Slice only a preview_shape-sized region from the center of the input halfshape = np.array(preview_shape) // 2 inp_center = inp_shape // 2 inp_lo = inp_center - halfshape inp_hi = inp_center + halfshape if np.any(inp_center < halfshape): raise ValueError( 'preview_shape is too big for shape of input source.' f'Requested {preview_shape}, but can only deliver {tuple(inp_shape)}.' ) memstr = ' (in memory)' if in_memory else '' logger.info(f'\nPreview data{memstr}:') logger.info(f' input: {fname}[{key}]: {inp_h5.shape} ({inp_h5.dtype})\n') inp_np = slice_3d(inp_h5, inp_lo, inp_hi, prepend_empty_axis=True) if inp_np.ndim == dim + 1: # Should be dim + 2 for (N, C) dims inp_np = inp_np[:, None] # Add missing C dim if transform is not None: for n in range(inp_np.shape[0]): # N is usually 1, so this is only iterated once with n=0 inp_np[0], _ = transform(inp_np[0], None) inp = torch.from_numpy(inp_np) return inp
[docs] class SimpleNeuroData2d(data.Dataset): """ 2D Dataset class for neuro_data_cdhw, reading from a single HDF5 file. Delivers 2D image slices from the (H, W) plane at given D indices. Not scalable, keeps everything in memory. This is just a minimalistic proof of concept. """ def __init__( self, inp_path=None, target_path=None, train=True, inp_key='raw', target_key='lab', # offset=(0, 0, 0), pool=(1, 1, 1), transform: Callable = transforms.Identity(), out_channels: Optional[int] = None, ): super().__init__() self.transform = transform self.out_channels = out_channels cube_id = 0 if train else 2 if inp_path is None: inp_path = expanduser(f'~/neuro_data_cdhw/raw_{cube_id}.h5') if target_path is None: target_path = expanduser(f'~/neuro_data_cdhw/barrier_int16_{cube_id}.h5') self.inp_file = h5py.File(os.path.expanduser(inp_path), 'r') self.target_file = h5py.File(os.path.expanduser(target_path), 'r') self.inp = self.inp_file[inp_key][()].astype(np.float32) self.target = self.target_file[target_key][()].astype(np.int64) self.target = self.target[0] # Squeeze superfluous first dimension self.target = self.target[::pool[0], ::pool[1], ::pool[2]] # Handle pooling (dirty hack TODO) # Cut inp and target to same size inp_shape = np.array(self.inp.shape[1:]) target_shape = np.array(self.target.shape) diff = inp_shape - target_shape offset = diff // 2 # offset from image boundaries self.inp = self.inp[ :, offset[0]: inp_shape[0] - offset[0], offset[1]: inp_shape[1] - offset[1], offset[2]: inp_shape[2] - offset[2], ] self.close_files() # Using file contents from memory -> no need to keep the file open. def __getitem__(self, index): # Get z slices inp = self.inp[:, index] target = self.target[index] inp, target = self.transform(inp, target) inp = torch.as_tensor(inp) target = torch.as_tensor(target) sample = { 'inp': inp, 'target': target, 'cube_meta': np.inf, 'fname': str(index) } return sample def __len__(self): return self.target.shape[0]
[docs] def close_files(self): self.inp_file.close() self.target_file.close()
# TODO: docs, types
[docs] class Segmentation2d(data.Dataset): """Simple dataset for 2d segmentation. Expects a list of ``input_paths`` and ``target_paths`` where ``target_paths[i]`` is the target of ``input_paths[i]`` for all i. """ def __init__( self, inp_paths, target_paths, transform=transforms.Identity(), offset: Sequence[int] = (0 ,0, 0), in_memory=True, inp_dtype=np.float32, target_dtype=np.int64, epoch_multiplier=1, # Pretend to have more data in one epoch ): super().__init__() self.inp_paths = inp_paths self.target_paths = target_paths self.transform = transform self.offset = offset self.in_memory = in_memory self.inp_dtype = inp_dtype self.target_dtype = target_dtype self.epoch_multiplier = epoch_multiplier def load_image(fname): inp = imageio.imread(fname).astype(np.float32) if inp.ndim == 2: inp = inp[None] # (H, W) -> (C=1, H, W) elif inp.ndim == 3: inp = inp.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) else: raise RuntimeError(f'Image {fname} has shape {inp.shape}, but ndim should be 2 or 3.') return inp if self.in_memory: self.inputs = [] rgb_fnames = {} gray_fnames = {} for input_path in self.inp_paths: if os.path.isdir(input_path): multi_input = [] for channel_idx, input_file in enumerate(sorted(glob.glob(str(input_path) + '/*'))): inp = load_image(str(input_file)) if inp.shape[0] == 1: gray_fnames[channel_idx] = input_file elif inp.shape[0] == 3: rgb_fnames[channel_idx] = input_file rgb_fname = rgb_fnames.get(channel_idx) if rgb_fname is not None and inp.shape[0] == 1: raise RuntimeError(f'GT input layer {channel_idx} has mixed multi-channel ({rgb_fname}) and single-channel images ({input_file}).') gray_fname = gray_fnames.get(channel_idx) if gray_fname is not None and inp.shape[0] == 3: raise RuntimeError(f'GT input layer {channel_idx} has mixed multi-channel ({input_file}) and single-channel images ({gray_fname}).') multi_input.append(inp) self.inputs.append(np.concatenate(multi_input)) else: inp = load_image(input_path) if inp.shape[0] == 1: gray_fnames[0] = input_path elif inp.shape[0] == 3: rgb_fnames[0] = input_path if len(rgb_fnames) > 0 and inp.shape[0] == 1 or len(gray_fnames) > 0 and inp.shape[0] == 3: raise RuntimeError(f'Mixed multi-channel ({rgb_fnames[0]}) and single-channel images ({gray_fnames[0]}) in gt.') self.inputs.append(inp) self.targets = [ np.array(imageio.imread(fname)).astype(np.int64) for fname in self.target_paths ] def __getitem__(self, index): index %= len(self.inp_paths) # Wrap around to support epoch_multiplier if self.in_memory: inp = self.inputs[index] target = self.targets[index] else: fname = self.inp_paths[index] inp = imageio.imread(fname).astype(np.float32) if inp.ndim == 2: inp = inp[None] # (H, W) -> (C=1, H, W) elif inp.ndim == 3: inp = inp.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) else: raise RuntimeError(f'Image {fname} has shape {inp.shape}, but ndim should be 2 or 3.') target = np.array(imageio.imread(self.target_paths[index]), dtype=np.int64) while True: # Only makes sense if RandomCrop is used try: inp, target = self.transform(inp, target) break except transforms._DropSample: pass if np.any(self.offset): off = self.offset target = target[off[0]:-off[0], off[1]:-off[1]] sample = { 'inp': torch.as_tensor(inp.astype(self.inp_dtype)), 'target': torch.as_tensor(target.astype(self.target_dtype)), 'cube_meta': np.inf, 'fname': str(self.inp_paths[index]) } return sample def __len__(self): return len(self.target_paths) * self.epoch_multiplier
[docs] def set_offset(self, offset: Sequence[int]) -> None: self.offset = offset
# TODO: Document
[docs] class Reconstruction2d(data.Dataset): """Simple dataset for 2d reconstruction for auto-encoders etc.. """ def __init__( self, inp_paths, transform=transforms.Identity(), in_memory=True, inp_dtype=np.float32, epoch_multiplier=1, # Pretend to have more data in one epoch ): super().__init__() self.inp_paths = inp_paths self.transform = transform self.in_memory = in_memory self.inp_dtype = inp_dtype self.epoch_multiplier = epoch_multiplier if self.in_memory: self.inputs = [ np.array(imageio.imread(fname)).astype(np.float32)[None] for fname in self.inp_paths ] def __getitem__(self, index): index %= len(self.inp_paths) # Wrap around to support epoch_multiplier if self.in_memory: inp = self.inputs[index] else: inp = np.array(imageio.imread(self.inp_paths[index]), dtype=self.inp_dtype) if inp.ndim == 2: # (H, W) inp = inp[None] # (C=1, H, W) while True: # Only makes sense if RandomCrop is used try: inp, _ = self.transform(inp, None) break except transforms._DropSample: pass inp = torch.as_tensor(inp) sample = { 'inp': inp, 'target': inp, 'cube_meta': np.inf, 'fname': str(self.inp_paths[index]) } return sample def __len__(self): return len(self.inp_paths) * self.epoch_multiplier
class TripletData2d(data.Dataset): """Simple dataset for 2D triplet loss training. """ def __init__( self, inp_paths, transform=transforms.Identity(), invariant_transform=None, in_memory=True, inp_dtype=np.float32, epoch_multiplier=1, # Pretend to have more data in one epoch ): super().__init__() self.inp_paths = inp_paths self.transform = transform self.invariant_transform = invariant_transform self.in_memory = in_memory self.inp_dtype = inp_dtype self.epoch_multiplier = epoch_multiplier if self.in_memory: self.inputs = [ np.array(imageio.imread(fname)).astype(np.float32)[None] for fname in self.inp_paths ] def _get(self, index): if self.in_memory: inp = self.inputs[index] else: inp = np.array(imageio.imread(self.inp_paths[index]), dtype=self.inp_dtype) if inp.ndim == 2: # (H, W) inp = inp[None] # (C=1, H, W) while True: # Only makes sense if RandomCrop is used try: inp, _ = self.transform(inp, None) break except transforms._DropSample: pass return inp def _randidx_excluding(self, exclude): while True: idx = np.random.randint(0, len(self.inp_paths) // self.epoch_multiplier) if idx != exclude: return idx def __getitem__(self, index): index %= len(self.inp_paths) # Wrap around to support epoch_multiplier anchor = self._get(index) if self.invariant_transform is None: # Assuming a random augmentation transform, the positive image will be different than # the anchor, but it will originate from the same image file. # If random cropping and geometrical transforms are used, make sure that the loss is # not calculated on localized/spatial outputs! pos = self._get(index) else: # Apply an additional transform against which the network should learn invariant behavior pos, _ = self.invariant_transform(anchor, None) # Sample a negative image from a random different index -> different image neg_idx = self._randidx_excluding(index) neg = self._get(neg_idx) if self.invariant_transform is not None: # Also apply the invariant transform to the negative image because otherwise # the model could "cheat" by detecting that the inherent features of this # transform only exist in the positive image. neg, _ = self.invariant_transform(neg, None) sample = { 'anchor': torch.as_tensor(anchor), 'pos': torch.as_tensor(pos), 'neg': torch.as_tensor(neg), 'fname': f'ap{index}n{neg_idx}' } return sample def __len__(self): return len(self.inp_paths) * self.epoch_multiplier # TODO: Warn if datasets have no content