# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Ravil Dorozhinskii, Martin Drawitsch
import numpy as np
from itertools import product
from scipy import ndimage
import os
from elektronn3.data.utils import save_to_h5
from elektronn3.data.transforms.region_generator import RegionGenerator
import logging
logger = logging.getLogger('elektronn3log')
class IncorrectLimits(Exception):
    """Raised when limit/range parameters are inconsistent (e.g. a lower
    bound exceeds an upper bound, or a value lies outside [0, 1])."""
class IncorrectThreshold(Exception):
    """Raised when a threshold parameter is invalid."""
class IncorrectValue(Exception):
    """Raised when a parameter takes a value outside its allowed set
    (e.g. an unknown growth-type string)."""
class IncorrectType(Exception):
    """Raised when an argument has an unexpected type."""
class FunctionCallsCounter:
    """Module-wide call counter.

    Holds a single class attribute that is read and incremented in place
    (see its use in ``apply_random_blurring``) to number successive calls,
    e.g. for periodic snapshot file names.
    """
    # Shared across the module; mutated as FunctionCallsCounter.counter += 1.
    counter = 0
class ScalarScheduler:
    """Schedules a scalar value along an iterative process.

    The value grows from ``value`` to ``max_value`` over ``interval`` calls
    to :meth:`step`, following either linear or exponential growth.  The
    caller invokes :meth:`step` once per iteration to update and retrieve
    the current value.  If ``max_value`` or ``interval`` is not given, the
    scheduler degenerates to a constant and :meth:`step` returns the
    initial value forever.
    """

    def __init__(self,
                 value: float,
                 max_value: float = None,
                 growth_type: str = None,
                 interval: int = None,
                 steps_per_report: int = None):
        """Initialize the scheduler and validate its configuration.

        Parameters
        ----------
        value - scalar value at the beginning of the scheduled process
        max_value - scalar value at the end of the scheduled process
        growth_type - type of growth: "lin" - linear; "exp" - exponential
        interval - number of steps within which the scalar value is raised
            from the initial value to the maximal one
        steps_per_report - number of steps between log reports; None (the
            default) disables reporting entirely

        Raises
        ------
        IncorrectLimits - if ``value`` exceeds ``max_value``
        IncorrectValue - if ``growth_type`` is neither "lin" nor "exp"
            while scheduling is enabled
        """
        # Explicit None checks: max_value == 0.0 or interval == 0 are falsy
        # but are still user-provided values, so plain truthiness tests
        # would silently disable validation/scheduling for them.
        if max_value is not None and value > max_value:
            raise IncorrectLimits(f'ERROR: threshold limits are wrong: '
                                  f'initial_threshold_value = {value}, '
                                  f'max_threshold_value = {max_value}')
        self.value = value
        if max_value is not None and interval is not None:
            self.max_value = max_value
            self.interval = interval
            if growth_type == "lin":
                self.update_function = self.lin_update
                # Constant additive increment per step.
                self.base = (max_value - value) / self.interval
            elif growth_type == "exp":
                self.update_function = self.exp_update
                # Constant multiplicative factor per step.
                # NOTE(review): assumes value != 0 — a zero initial value
                # would divide by zero here.
                self.base = (self.max_value / self.value) ** (1.0 / self.interval)
            else:
                raise IncorrectValue('ERROR: ScalarScheduler class can only '
                                     'take "growth_type" parameter with values '
                                     f'either "lin" or "exp". Value "{growth_type}" '
                                     'has been passed instead')
        else:
            # Scheduling disabled: the value stays constant.
            self.update_function = self.idle_update
        self.steps_per_report = steps_per_report
        self.counter = 0

    def step(self) -> float:
        """Perform one update of the scheduled value.

        Returns
        -------
        The current scalar value after the update.
        """
        self.update_function()
        self._print_report()
        return self.value

    def lin_update(self) -> None:
        """Apply one linear-growth increment, clamped to ``max_value``."""
        self.value += self.base
        self.value = min(self.value, self.max_value)

    def exp_update(self) -> None:
        """Apply one exponential-growth factor, clamped to ``max_value``."""
        self.value *= self.base
        self.value = min(self.value, self.max_value)

    def idle_update(self) -> None:
        """No-op update (keeps the scheduled value constant)."""
        pass

    def _print_report(self) -> None:
        """Log the current value every ``steps_per_report`` calls.

        Counts the number of step() calls and logs whenever the count is
        divisible by ``steps_per_report``; does nothing if
        ``steps_per_report`` is None.
        """
        if self.steps_per_report is not None:
            if (self.counter % self.steps_per_report) == 0:
                logger.info('ScalarScheduler: '
                            f'value: {self.value}, '
                            f'counter: {self.counter}')
            self.counter += 1
def check_random_data_blurring_config(patch_shape: list,
                                      probability: float,
                                      threshold: ScalarScheduler,
                                      lower_lim_region_size: list,
                                      upper_lim_region_size: list,
                                      verbose: bool = False,
                                      save_path: str = None,
                                      num_steps_save: int = None) -> None:
    """Validate random-data-blurring parameters up front.

    Ensures the parameters will not cause problems during later
    ``apply_random_blurring`` calls, raising on the first conflict found.
    Call this once before training to be sure the config fulfills the
    requirements posed by ``apply_random_blurring``.

    Parameters
    ----------
    patch_shape - shape of input samples (spatial dimensions only)
    probability - probability of applying the random blurring algorithm
    threshold - controls the level of random blurring with respect to
        input sample volume; must be a ScalarScheduler
    lower_lim_region_size - min values of region size along each axis
    upper_lim_region_size - max values of region size along each axis
    verbose - unused here; kept for signature parity with
        ``apply_random_blurring``
    save_path - path to the directory for blurred-sample "h5" snapshots;
        created if it does not exist
    num_steps_save - unused here; kept for signature parity with
        ``apply_random_blurring``

    Raises
    ------
    IncorrectLimits - on dimensionality mismatch, inconsistent or
        oversized region limits, or out-of-range threshold/probability
    IncorrectType - if ``threshold`` is not a ScalarScheduler
    """
    # All three lists must describe the same number of axes.
    if (len(lower_lim_region_size) != len(upper_lim_region_size) or
            len(patch_shape) != len(lower_lim_region_size) or
            len(patch_shape) != len(upper_lim_region_size)):
        raise IncorrectLimits(f'ERROR: the region limits or/and input sample '
                              f'have different dimensionality:\n'
                              f'dimension of lower region limits: {len(lower_lim_region_size)}\n'
                              f'dimension of upper region limits: {len(upper_lim_region_size)}\n'
                              f'dimension of sample: {len(patch_shape)}')
    dim = len(patch_shape)
    # Per-axis: lower limit must be strictly below the upper limit.
    for i in range(dim):
        if lower_lim_region_size[i] >= upper_lim_region_size[i]:
            raise IncorrectLimits(f'ERROR: region limits are inconsistent at axis={i}:\n'
                                  f'min = {lower_lim_region_size[i]}\n'
                                  f'max = {upper_lim_region_size[i]}\n')
    # Per-axis: a region must fit strictly inside the input sample.
    for i in range(dim):
        if upper_lim_region_size[i] >= patch_shape[i]:
            raise IncorrectLimits(f'ERROR: region size exceeds input sample at axis={i}:\n'
                                  f'region size = {upper_lim_region_size[i]}\n'
                                  f'sample size = {patch_shape[i]}\n')
    # The threshold must be a ScalarScheduler (apply_random_blurring calls
    # threshold.step() and reads threshold.value).
    if not isinstance(threshold, ScalarScheduler):
        raise IncorrectType(f'ERROR: threshold type is not type of ScalarScheduler\n'
                            f'instead, it has its type of: {type(threshold)}')
    # Threshold is a volume fraction: must lie within [0.0, 1.0].
    if threshold.value < 0.0 or threshold.value > 1.0:
        raise IncorrectLimits(f'ERROR: threshold of random data blurring is out '
                              f'of the range [0.0,1.0]: threshold = {threshold.value}')
    # Probability must lie within [0.0, 1.0].
    if probability < 0.0 or probability > 1.0:
        raise IncorrectLimits(f'ERROR: probability of random data blurring is out '
                              f'of the limits [0.0,1.0]: probability = {probability}')
    # Ensure the snapshot directory exists; exist_ok avoids the
    # check-then-create race of the previous isdir()/makedirs() pair.
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
def apply_random_blurring(inp_sample: np.ndarray,
                          probability: float,
                          threshold: ScalarScheduler,
                          lower_lim_region_size: list,
                          upper_lim_region_size: list,
                          verbose: bool = False,
                          save_path: str = None,
                          num_steps_save: int = None) -> None:
    """Blur randomly placed regions of ``inp_sample`` in place.

    First draws a random number in [0, 1); if it exceeds ``probability``
    the function returns without modifying the sample.  Otherwise, for
    each channel it generates random regions until the accumulated
    (deduplicated) region volume reaches the fraction of the sample
    volume given by ``threshold``.  Region shapes are drawn uniformly
    within the user-specified per-axis limits and each region is blurred
    with a Gaussian filter of random strength.

    Parameters
    ----------
    inp_sample - input sample with the format (C, D, H, W); modified in place
    probability - probability of applying the random blurring algorithm
    threshold - controls the level of random blurring with respect to
        the input sample volume; stepped once per call
    lower_lim_region_size - min values of region size along each axis
    upper_lim_region_size - max values of region size along each axis
    verbose - if True, log the achieved blur fraction per channel
    save_path - directory for periodic "h5" snapshots of the blurred
        sample; None disables saving
    num_steps_save - number of calls between snapshot writes; None
        disables saving

    Returns
    -------
    None
    """
    if np.random.rand() > probability:
        return
    num_channels, sample_depth, sample_width, sample_height = inp_sample.shape
    # Volume of one channel (spatial axes only).
    sample_volume = np.prod(inp_sample.shape[1:])
    coord_bounds = [sample_depth, sample_width, sample_height]
    generator = RegionGenerator(coord_bounds,
                                lower_lim_region_size,
                                upper_lim_region_size)
    # Advance the scheduled blur-fraction target once per call.
    threshold.step()
    for channel_idx in range(num_channels):
        blurring_percentage = 0.0
        # Set of voxel coordinates blurred so far, so overlapping regions
        # are not double-counted toward the threshold.
        covered_voxels = set()
        while blurring_percentage < threshold.value:
            region = generator.create_region()
            for k, i, j in product(range(region.coords_lo[0], region.coords_hi[0] + 1),
                                   range(region.coords_lo[1], region.coords_hi[1] + 1),
                                   range(region.coords_lo[2], region.coords_hi[2] + 1)):
                covered_voxels.add((k, i, j))
            snippet = inp_sample[channel_idx,
                                 region.coords_lo[0]:region.coords_hi[0] + 1,
                                 region.coords_lo[1]:region.coords_hi[1] + 1,
                                 region.coords_lo[2]:region.coords_hi[2] + 1]
            # abs() fix: np.random.randn() is negative ~half the time, and
            # scipy's gaussian_filter cannot build a kernel from a negative
            # sigma; use the magnitude instead (half-normal scale).
            gaussian_std = [abs(np.random.randn()) * size for size in region.size]
            snippet = ndimage.gaussian_filter(snippet, gaussian_std)
            inp_sample[channel_idx,
                       region.coords_lo[0]:region.coords_hi[0] + 1,
                       region.coords_lo[1]:region.coords_hi[1] + 1,
                       region.coords_lo[2]:region.coords_hi[2] + 1] = snippet
            blurred_volume = len(covered_voxels)
            blurring_percentage = blurred_volume / sample_volume
        if verbose:
            logger.info(f'blur percentage for channel {channel_idx}: {blurring_percentage}')
    if save_path and num_steps_save:
        # Snapshot every num_steps_save-th call, numbered by the
        # module-wide call counter.
        if (FunctionCallsCounter.counter % num_steps_save) == 0:
            dictionary = {f'channel-{i}': inp_sample[i]
                          for i in range(num_channels)}
            file_name = f'randomly_blurred_sample-{FunctionCallsCounter.counter}.h5'
            save_to_h5(data=dictionary,
                       path=os.path.join(save_path, file_name),
                       overwrite=False,
                       compression=False)
        FunctionCallsCounter.counter += 1