# -*- coding: utf-8 -*-
# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany
# Authors: Marius Killinger, Martin Drawitsch, Philipp Schubert
import logging
import os
import numpy as np
from matplotlib import pyplot as plt
logger = logging.getLogger('elektronn3log')
def _scroll_plot1(image, name, init_z):
    """
    Show one 3d volume image in a single scrollable panel

    A 12x12 inch figure with one axes is created and handed, together
    with the image and its name, to a ``Scroller``. The figure's
    ``scroll_event`` is routed to ``Scroller.onscroll``, so scrolling
    changes the displayed slice.

    Parameters
    ----------
    image: array of shape (z,x,y) or (z,x,y,RGB)
    name: string
      Name for the image
    init_z:
      Passed through unchanged to ``Scroller`` (presumably the slice
      shown first -- confirm against ``Scroller``)

    Returns
    -------
    scroller: Scroller
      Only the scroller is returned, not the figure.

    Usage
    -----
    For the scroll interaction to work, the "scroller" object
    must be kept in the calling scope

    >>> scroller = _scroll_plot1(image, name, init_z)
    >>> plt.show()
    """
    figure = plt.figure(figsize=(12, 12))
    panel = figure.add_subplot(111)
    handler = Scroller([panel], [image], [name], init_z)
    figure.canvas.mpl_connect('scroll_event', handler.onscroll)
    figure.tight_layout()
    return handler
def _scroll_plot2(images, names, init_z):
    """
    Show two 3d volume images side by side (1x2 grid)

    Both panels live in one 12x6 inch figure; the right panel shares its
    x and y axis with the left one. The figure's ``scroll_event`` is
    routed to ``Scroller.onscroll``, so scrolling changes the displayed
    slices.

    Parameters
    ----------
    images: list of 2 arrays
      Each array of shape (z,y,x) or (z,y,x,RGB)
    names: list of 2 strings
      Names for each image
    init_z:
      Passed through unchanged to ``Scroller`` (presumably the slice
      shown first -- confirm against ``Scroller``)

    Returns
    -------
    scroller: Scroller
      Only the scroller is returned, not the figure.

    Usage
    -----
    For the scroll interaction to work, the "scroller" object
    must be kept in the calling scope

    >>> scroller = _scroll_plot2(images, names, init_z)
    >>> plt.show()
    """
    figure = plt.figure(figsize=(12, 6))
    left = figure.add_subplot(121)
    right = figure.add_subplot(122, sharex=left, sharey=left)
    panels = [left, right]
    handler = Scroller(panels, images, names, init_z)
    figure.canvas.mpl_connect('scroll_event', handler.onscroll)
    figure.tight_layout()
    return handler
def _scroll_plot4(images, names, init_z):
    """
    Show four 3d volume images in a 2x2 grid

    All panels live in one 12x12 inch figure; panels 2-4 share their
    x and y axis with the first panel. The figure's ``scroll_event`` is
    routed to ``Scroller.onscroll``, so scrolling changes the displayed
    slices.

    Parameters
    ----------
    images: list of 4 arrays
      Each array of shape (z,y,x) or (z,y,x,RGB)
    names: list of 4 strings
      Names for each image
    init_z:
      Passed through unchanged to ``Scroller`` (presumably the slice
      shown first -- confirm against ``Scroller``)

    Returns
    -------
    scroller: Scroller
      Only the scroller is returned, not the figure.

    Usage
    -----
    For the scroll interaction to work, the "scroller" object
    must be kept in the calling scope

    >>> scroller = _scroll_plot4(images, names, init_z)
    >>> plt.show()
    """
    figure = plt.figure(figsize=(12, 12))
    anchor = figure.add_subplot(221)
    panels = [anchor]
    # Remaining grid cells, each tied to the first panel's axes.
    for grid_code in (222, 223, 224):
        panels.append(
            figure.add_subplot(grid_code, sharex=anchor, sharey=anchor))
    handler = Scroller(panels, images, names, init_z)
    figure.canvas.mpl_connect('scroll_event', handler.onscroll)
    figure.tight_layout()
    return handler
[docs]
def sma(c, n):
    """
    Returns the box-SMA (simple moving average) of c with box length n.

    The returned array has the same length as c. From index ``n`` on,
    each entry is the mean of the last ``n`` input values (current value
    included). The first ``n`` entries, for which no full box exists yet,
    hold the running mean of all values seen so far.

    Parameters
    ----------
    c: array-like of numbers, expected to be one-dimensional
    n: int
      Box length. ``0`` disables smoothing.

    Returns
    -------
    ``c`` itself (same object, not converted) if ``n == 0``, otherwise a
    new float array.

    Raises
    ------
    ValueError
      If ``n`` is negative.
    """
    if n == 0:
        return c
    if n < 0:
        # A negative box would otherwise hit an obscure broadcasting error
        # or, for short inputs, silently produce a meaningless result.
        raise ValueError('Box length n must be non-negative, got %r' % (n,))
    ret = np.cumsum(c, dtype=float)
    # Cumulative sums n apart differ by exactly the sum over the last n values.
    ret[n:] = (ret[n:] - ret[:-n]) / n
    # Head of the array: divide each partial sum by the number of values so far.
    m = min(n, len(c))
    ret[:n] = ret[:n] / np.arange(1, m + 1)
    return ret
[docs]
def add_timeticks(ax, times, steps, time_str='mins', num=5):
    """
    Label the x-axis of ``ax`` with runtime values at round intervals.

    Tick labels are runtime values spaced by a multiple of a power of
    ten, chosen so that roughly ``num`` ticks cover the total runtime
    ``times[-1]``. Tick positions are obtained by interpolating those
    runtime values from ``times`` onto ``steps``.

    Parameters
    ----------
    ax:
      Object providing ``set_xticks``, ``set_xticklabels``, ``set_xlim``
      and ``set_xlabel`` (e.g. a matplotlib axes)
    times: sequence of numbers
      Runtime at each point, used as interpolation x-coordinates
    steps: sequence of numbers
      Axis coordinate belonging to each entry of ``times``
    time_str: string
      Unit name shown in the axis label
    num: int
      Approximate number of ticks aimed for
    """
    total = int(times[-1])
    # Power of ten matching the magnitude of the ideal tick spacing.
    exponent = int(np.log10(max(total / num, 1)))
    # How many of those powers of ten fit into one tick interval.
    multiple = int(np.round(float(total) / (num * 10 ** exponent)))
    spacing = max(multiple * 10 ** exponent, 1)
    labels = np.arange(0, total, spacing, dtype=np.int64)
    positions = np.interp(labels, times, steps)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_xlim(0, steps[-1])
    ax.set_xlabel('Runtime [%s]' % time_str)  # (%s)'%("{0:,d}".format(N)))
[docs]
def plot_hist(timeline, history, save_path, loss_smoothing_length=200,
              autoscale=True):
    """
    Plot graphical info during Training and save it as two PNG files.

    ``timeline.png`` shows the per-update loss with a smoothed curve and a
    scatter plot of loss against ``batch_char``. ``history.png`` shows
    train/valid loss, the loss gain together with LR, momentum and
    gradnetrate, and the train/valid error. Both files are written into
    ``save_path``.

    Parameters
    ----------
    timeline:
      Record array with the fields 'time', 'loss' and 'batch_char'
      (presumably a numpy structured array -- confirm against the caller).
      'time' is treated as seconds: it is divided by 60 / 3600 for the
      minute / hour labels.
    history:
      Record array with the fields 'steps', 'train_loss', 'valid_loss',
      'loss_gain', 'lr', 'mom', 'gradnetrate', 'train_err' and 'valid_err'
      (presumably a numpy structured array -- confirm against the caller).
    save_path: string
      Directory the two PNG files are written to
    loss_smoothing_length: int
      Box length handed to ``sma`` for the smoothed update loss
    autoscale: bool
      If true, axis limits are derived from mean and standard deviation
      of the most recent values instead of matplotlib's defaults

    Notes
    -----
    Interactive mode is switched off with ``plt.ioff()``. ``ValueError``
    and ``IndexError`` raised while plotting (e.g. for empty inputs) are
    logged and not propagated. Every figure opened here is closed again
    before returning, so repeated calls do not accumulate open figures.
    """
    plt.ioff()
    figures = []  # every figure opened below, so it can be closed afterwards
    try:
        # Subsample points for plotting (at most roughly 2000 of each)
        N = len(timeline)
        x_timeline = np.arange(N)
        s = max((len(timeline) // 2000), 1)
        x_timeline = x_timeline[::s]
        timeline = timeline[::s]
        s = max((len(history) // 2000), 1)
        history = history[::s]

        if timeline['time'][-1] < 120 * 60:
            runtime = str(int(timeline['time'][-1] / 60)) + ' mins'
        else:
            runtime = "%.1f hours" % (timeline['time'][-1] / 3600)

        # check if valid data is available
        if not np.any(np.isnan(history['valid_loss'])):
            recent_loss = history['valid_loss'][-10:]
        else:
            recent_loss = history['train_loss'][-10:]

        # Upper y-limit: the larger of "mean + 2 std" over the recent
        # history losses and over the recent update losses.
        loss_cap = recent_loss.mean() + 2 * recent_loss.std()
        lt = timeline['loss'][-200:]
        lt_m = lt.mean()
        lt_s = lt.std()
        loss_cap_t = lt_m + 2 * lt_s
        loss_cap = np.maximum(loss_cap, loss_cap_t)

        if np.all(timeline['loss'] > 0):
            loss_floor = 0.0
        else:
            loss_floor = lt_m - 2 * lt_s

        ### Timeline, Loss ###
        figures.append(plt.figure(figsize=(16, 12)))
        plt.subplot(211)
        plt.plot(x_timeline, timeline['loss'], 'b-', alpha=0.5,
                 label='Update Loss')
        loss_smooth = sma(timeline['loss'], loss_smoothing_length)
        plt.plot(x_timeline, loss_smooth, 'k-', label='Smooth update Loss',
                 linewidth=3)

        if autoscale:
            plt.ylim(loss_floor, loss_cap)
            plt.xlim(0, N)

        plt.legend(loc=0)
        # Mean and +/- one std of the recent update loss as guide lines
        plt.hlines(lt_m, 0, N, linestyle='dashed', colors='r', linewidth=2)
        plt.hlines(lt_m + lt_s, 0, N, linestyle='dotted', colors='r',
                   linewidth=1)
        plt.hlines(lt_m - lt_s, 0, N, linestyle='dotted', colors='r',
                   linewidth=1)
        plt.xlabel('Update steps %s, total runtime %s' % (N - 1, runtime))

        _add_runtime_axis(timeline, x_timeline)

        ### Loss vs Prevalence ###
        plt.subplot(212)
        # Gray level encodes time: 1.0 at the start, 0.0 at the latest point
        c = 1.0 - (timeline['time'] / timeline['time'].max())
        plt.scatter(timeline['batch_char'], timeline['loss'], c=c, marker='.',
                    s=80, cmap='gray', edgecolors='face')
        if autoscale:
            bc = timeline['batch_char'][-200:]
            bc_m = bc.mean()
            bc_s = bc.std()
            bc_cap = bc_m + 2 * bc_s
            if np.all(bc > 0):
                bc_floor = -0.01
            else:
                bc_floor = bc_m - 2 * bc_s
            plt.ylim(loss_floor, loss_cap)
            plt.xlim(bc_floor, bc_cap + 1e-10)

        plt.xlabel('Mean target of batch')
        plt.ylabel('Loss')

        plt.tight_layout()
        plt.savefig(os.path.join(save_path, 'timeline.png'),
                    bbox_inches='tight')

        ###################################################################
        ### History Loss ###
        figures.append(plt.figure(figsize=(16, 12)))
        plt.subplot(311)
        plt.plot(history['steps'], history['train_loss'], 'g-',
                 label='Train Loss', linewidth=3)
        plt.plot(history['steps'], history['valid_loss'], 'r-',
                 label='Valid Loss', linewidth=3)
        if autoscale:
            plt.ylim(loss_floor, loss_cap)
            plt.xlim(0, history['steps'][-1])

        plt.legend(loc=0)

        _add_runtime_axis(timeline, x_timeline)

        ### History Loss gains ###
        plt.subplot(312)
        plt.plot(history['steps'], history['loss_gain'], 'b-',
                 label='Loss Gain at update', linewidth=3)
        plt.hlines(0, 0, history['steps'][-1], linestyles='dotted')
        plt.plot(history['steps'], history['lr'], 'r-', label='LR',
                 linewidth=3)
        plt.legend(loc=3)
        std = history['loss_gain'][:5].std() * 2 if len(history) > 6 else 1.0
        if autoscale:
            # add epsilon to suppress matplotlib warning in case of CG
            plt.ylim(-std, std + 1e-10)
            plt.xlim(0, history['steps'][-1])

        ax2 = plt.twinx()
        ax2.plot(history['steps'], history['mom'], 'r-', label='MOM')
        ax2.plot(history['steps'], history['gradnetrate'], 'r-',
                 label='GradNetRate')
        ax2.set_ylim(-1, 1)
        if autoscale:
            ax2.set_xlim(0, history['steps'][-1])

        ax2.legend(loc=4)

        ### Errors ###
        plt.subplot(313)
        # Drop the first entries from the error plot when enough data exists
        cutoff = 2
        if len(history) > (cutoff + 1):
            history = history[cutoff:]
        # check if valid data is available
        if not np.any(np.isnan(history['valid_err'])):
            e = history['valid_err'][-200:]
        else:
            e = history['train_err'][-200:]
        e_m = e.mean()
        e_s = e.std()
        err_cap = e_m + 2 * e_s
        if np.all(e > 0):
            err_floor = 0.0
        else:
            err_floor = e_m - 2 * e_s

        plt.plot(history['steps'], history['train_err'], 'g--',
                 label='Train error', linewidth=1)
        plt.plot(history['steps'], history['valid_err'], 'r--',
                 label='Valid Error', linewidth=1)

        plt.plot(history['steps'], sma(history['train_err'], 8), 'g-',
                 label='Smooth train error', linewidth=3)
        # Smoothed validation error is only drawn when it contains no NaN
        smooth_valid_err = sma(history['valid_err'], 8)
        if not np.any(np.isnan(smooth_valid_err)):
            plt.plot(history['steps'], smooth_valid_err, 'r-',
                     label='Smooth valid Error', linewidth=3)
        if autoscale:
            try:
                plt.ylim(err_floor, err_cap)
            except ValueError:  # NaN errors can't be used for auto-scaling.
                pass
            plt.xlim(0, history['steps'][-1])

        plt.grid()
        plt.legend(loc=0)
        plt.xlabel('Update steps %s, total runtime %s' % (N - 1, runtime))
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, 'history.png'),
                    bbox_inches='tight')

    except (ValueError, IndexError):
        # When arrays are empty (indexing the last element of an empty
        # array raises IndexError, not ValueError)
        logger.exception("An error occurred during plotting.")
    finally:
        # Release the figures; otherwise every call leaves two more open.
        for fig in figures:
            plt.close(fig)


def _add_runtime_axis(timeline, x_timeline):
    """Add a twin x-axis with runtime ticks (hours above 2h, else minutes)."""
    ax = plt.twiny()
    if timeline['time'][-1] > 120 * 60:
        add_timeticks(ax, timeline['time'] / 3600, x_timeline,
                      time_str='hours')
    else:
        add_timeticks(ax, timeline['time'] / 60, x_timeline,
                      time_str='mins')