Source code for elektronn3.training.train_utils

# -*- coding: utf-8 -*-
# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Marius Killinger, Philipp Schubert, Martin Drawitsch

import logging
import time
import warnings
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import torch

from elektronn3.training import plotting
from elektronn3 import floatX

logger = logging.getLogger('elektronn3log')


[docs]def create_preview_batch_from_knossos(knossos_preview_config: Dict[str, str]) -> torch.Tensor: from knossos_utils import KnossosDataset config = knossos_preview_config required_keys = ['dataset', 'size', 'offset', 'mag', 'target_mags'] for k in required_keys: if k not in config.keys(): raise ValueError(f'Required key {k} missing from knossos_preview_config.') logger.info(f'Loading preview region from dataset {config["dataset"]}') datasets = config['dataset'] if isinstance(datasets, str): datasets = [datasets] inp_np = [] for idx, dataset_path in enumerate(datasets): ds = KnossosDataset(dataset_path) inp_np.append(ds.load_raw( offset=config['offset'], size=config['size'], mag=config['mag'], datatype=np.float32 )) # (D, H, W) inp_np = np.stack(inp_np, axis=0) # C, D, H, W inp_np = inp_np[None] # (N, C, D, H, W) inp_np = inp_np / config.get('scale_brightness', 1.) inp = torch.from_numpy(inp_np) return inp
[docs]class HistoryTracker: def __init__(self): self.plotting_proc = None self.debug_outputs = None self.regression_track = None self.debug_output_names = None self.timeline = AccumulationArray(n_init=int(1e5), dtype=dict( names=[u'time', u'loss', u'batch_char', ], formats=[u'f4', ] * 3)) self.history = AccumulationArray(n_init=int(1e4), dtype=dict( names=[u'steps', u'time', u'train_loss', u'valid_loss', u'loss_gain', u'train_err', u'valid_err', u'lr', u'mom', u'gradnetrate'], formats=[u'i4', ] + [u'f4', ] * 9)) self.loss = AccumulationArray(n_init=int(1e5), data=[])
[docs] def update_timeline(self, vals): self.timeline.append(vals) self.loss.append(vals[1])
[docs] def register_debug_output_names(self, names): self.debug_output_names = names
[docs] def update_history(self, vals): self.history.append(vals)
[docs] def update_debug_outputs(self, vals): if self.debug_outputs is None: self.debug_outputs = AccumulationArray(n_init=int(1e5), right_shape=len(vals)) self.debug_outputs.append(vals)
[docs] def update_regression(self, pred, target): if self.regression_track is None: assert len(pred)==len(target) p = AccumulationArray(n_init=int(1e5), right_shape=len(pred)) t = AccumulationArray(n_init=int(1e5), right_shape=len(pred)) self.regression_track = [p, t] self.regression_track[0].append(pred) self.regression_track[1].append(target)
[docs] def plot(self, save_path=None, autoscale=True, close=True, loss_smoothing_len=200): plotting.plot_hist(self.timeline, self.history, save_path, loss_smoothing_len, autoscale) if self.debug_output_names and self.debug_outputs.length: plotting.plot_debug(self.debug_outputs, self.debug_output_names, save_path) if self.regression_track: plotting.plot_regression(self.regression_track[0], self.regression_track[1], save_path) plotting.plot_kde(self.regression_track[0], self.regression_track[1], save_path) if close: plt.close('all')
# TODO: Try to remove this thing (or document/rewrite it)
[docs]class AccumulationArray: def __init__(self, right_shape=(), dtype=floatX, n_init=100, data=None, ema_factor=0.95): if isinstance(dtype, dict) and right_shape!=(): raise ValueError("If dict is used as dtype, right shape must be" "unchanged (i.e it is 1d)") if data is not None and len(data): n_init += len(data) right_shape = data.shape[1:] dtype = data.dtype self._n_init = n_init if isinstance(right_shape, int): self._right_shape = (right_shape,) else: self._right_shape = tuple(right_shape) self.dtype = dtype self.length = 0 self._buffer = self._alloc(n_init) self._min = +np.inf self._max = -np.inf self._sum = 0 self._ema = None self._ema_factor = ema_factor if data is not None and len(data): self.length = len(data) self._buffer[:self.length] = data self._min = data.min(0) self._max = data.max(0) self._sum = data.sum(0) def __repr__(self): return repr(self.data) def _alloc(self, n): if isinstance(self._right_shape, (tuple, list, np.ndarray)): ret = np.zeros((n,) + tuple(self._right_shape), dtype=self.dtype) elif isinstance(self.dtype, dict): # rec array ret = np.zeros(n, dtype=self.dtype) else: raise ValueError("dtype not understood") return ret
[docs] def append(self, data): # data = self.normalise_data(data) if len(self._buffer)==self.length: tmp = self._alloc(len(self._buffer) * 2) tmp[:self.length] = self._buffer self._buffer = tmp if isinstance(self.dtype, dict): for k, val in enumerate(data): self._buffer[self.length][k] = data[k] else: self._buffer[self.length] = data if self._ema is None: self._ema = self._buffer[self.length] else: f = self._ema_factor fc = 1 - f self._ema = self._ema * f + self._buffer[self.length] * fc self.length += 1 with warnings.catch_warnings(): warnings.simplefilter('ignore', RuntimeWarning) self._min = np.minimum(data, self._min) self._max = np.maximum(data, self._max) self._sum = self._sum + np.asanyarray(data)
[docs] def add_offset(self, off): self.data[:] += off if off.ndim>np.ndim(self._sum): off = off[0] self._min += off self._max += off self._sum += off * self.length
[docs] def clear(self): self.length = 0 self._min = +np.inf self._max = -np.inf self._sum = 0
[docs] def mean(self): return np.asarray(self._sum, dtype=floatX) / self.length
[docs] def sum(self): return self._sum
[docs] def max(self): return self._max
[docs] def min(self): return self._min
def __len__(self): return self.length @property def data(self): return self._buffer[:self.length] @property def ema(self): return self._ema def __getitem__(self, slc): return self._buffer[:self.length][slc]
[docs]class Timer: def __init__(self): self.origin = time.time() self.t0 = self.origin @property def t_passed(self): return time.time() - self.origin
[docs]def pretty_string_time(t): """Custom printing of elapsed time""" if t > 4000: s = 't=%.1fh' % (t / 3600) elif t > 300: s = 't=%.0fm' % (t / 60) else: s = 't=%.0fs' % (t) return s