# These are default plotting handlers that work in some common training
# scenarios, but won't work in every case:
import logging
import os
from typing import Dict, Optional, Callable, Sequence
import matplotlib.figure
import matplotlib.pyplot as plt
import matplotlib.cm
import numpy as np
import torch
from torch.nn import functional as F
from elektronn3.data.utils import squash01
from elektronn3.data.transforms import RemapTargetIDs
# Colormap name for class plots, configurable via environment variable.
# ``os.getenv`` returns ``None`` when the variable is unset, so the
# annotation must be Optional[str] (the original ``str`` annotation was wrong).
E3_CMAP: Optional[str] = os.getenv('E3_CMAP')
logger = logging.getLogger('elektronn3log')
def get_cmap(out_channels: int):
    """Return a discrete matplotlib colormap with ``out_channels`` colors.

    The colormap name is taken from the ``E3_CMAP`` environment variable if
    that is set; otherwise a default qualitative map is picked based on the
    number of channels ('tab10' for up to 10, 'tab20' for up to 20).

    Raises:
        RuntimeError: If no cmap is configured via ``E3_CMAP`` and
            ``out_channels`` exceeds the 20 colors of the default cmaps.
    """
    cmname = E3_CMAP
    if cmname is None:
        # No cmap configured via envvar -> fall back to defaults
        if out_channels <= 10:
            cmname = 'tab10'
        elif out_channels <= 20:
            cmname = 'tab20'
        else:
            raise RuntimeError(
                f'Default cmaps only support up to 20 colors, which are not enough to label '
                f'{out_channels} different output channels.\nPlease set a different cmap '
                'with the E3_CMAP envvar.'
            )
    return matplotlib.cm.get_cmap(cmname, out_channels)
def plot_image(
        image: np.ndarray,
        overlay: Optional[np.ndarray] = None,
        overlay_alpha=0.5,
        cmap=None,
        colorbar=True,
        filename=None,
        vmin=None,
        vmax=None
) -> matplotlib.figure.Figure:
    """Plots a 2D image to a matplotlib figure.

    For gray-scale images, use ``cmap='gray'``.
    For label matrices (segmentation targets or class predictions),
    pass a qualitative cmap (see :py:func:`get_cmap`) together with
    ``vmin=0`` and ``vmax=<number of classes>`` to get a discrete colorbar.

    Args:
        image: 2D array (H, W) or (1, H, W); a leading singleton channel
            axis is squeezed away before plotting.
        overlay: Optional label array drawn over ``image``. Zero-valued
            regions are masked out (treated as transparent background).
        overlay_alpha: Opacity of the overlay.
        cmap: Colormap name (str), Colormap instance or ``None``
            (matplotlib default).
        colorbar: If ``True``, attach a colorbar to the figure.
        filename: Optional name shown as the plot title (truncated from
            the left if longer than 50 characters).
        vmin: Lower color scale limit.
        vmax: Upper color scale limit; doubles as the number of classes
            for discrete (qualitative) colorbar ticks.

    Returns:
        The created matplotlib Figure.
    """
    # Determine colormap and set discrete color values if needed.
    ticks = None
    ticklabels = None
    # cmap can be None (default), a name string or a Colormap instance.
    # BUGFIX: the original unconditionally accessed ``cmap.name``, which
    # raised AttributeError for the default ``cmap=None``.
    if isinstance(cmap, str):
        cmap_name = cmap
    else:
        cmap_name = getattr(cmap, 'name', None)
    # Only qualitative cmaps with a known class count (vmax) get discrete,
    # centered colorbar ticks. The explicit None checks also prevent the
    # unset-E3_CMAP case (None) from matching the set and crashing on
    # ``vmax - 0.5`` with vmax=None.
    if cmap_name is not None and vmax is not None and cmap_name in {E3_CMAP, 'tab10', 'tab20'}:  # qualitative cmap
        ticks = np.linspace(0.5, vmax - 0.5, vmax)  # 0.5 for centered ticks
        ticklabels = np.arange(vmax)
    fig, ax = plt.subplots(constrained_layout=True, figsize=(10, 10))
    if image.ndim == 3 and image.shape[0] == 1:  # (1, H, W) -> (H, W)
        image = image[0]
    if overlay is None:
        aximg = ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none')
    else:
        # Grayscale base image with the overlay drawn on top;
        # zero-labeled (background) overlay regions stay transparent.
        ax.imshow(image, cmap='gray')
        masked_overlay = np.ma.masked_where(overlay == 0, overlay)
        aximg = ax.imshow(masked_overlay, cmap=cmap, vmin=vmin, vmax=vmax, alpha=overlay_alpha, interpolation='none')
    if filename is not None:
        max_filename_length = 50  # Truncate long file names from the left
        if len(filename) > max_filename_length:
            filename = f'...{filename[-max_filename_length:]}'
        ax.set_title(filename)
    if colorbar:
        bar = fig.colorbar(aximg, ticks=ticks)
        if ticklabels is not None:
            bar.set_ticklabels(ticklabels)
        bar.solids.set(alpha=1)  # otherwise uses image’s opacity
    return fig
def _get_batch2img_function(
batch: np.ndarray,
z_plane: Optional[int] = None
) -> Callable[[np.ndarray], np.ndarray]:
"""
Defines ``batch2img`` function dynamically, depending on tensor shapes.
``batch2img`` slices a 4D or 5D tensor to (C, H, W) shape, moves it to
host memory and converts it to a numpy array.
By arbitrary choice, the first element of a batch is always taken here.
In the 5D case, the D (depth) dimension is sliced at z_plane.
This function is useful for plotting image samples during training.
Args:
batch: 4D or 5D tensor, used for shape analysis.
z_plane: Index of the spatial plane where a 5D image tensor should
be sliced. If not specified, this is automatically set to half
the size of the D dimension.
Returns:
Function that slices a plottable 2D image out of a np.ndarray
with batch and channel dimensions.
"""
if batch.ndim == 5: # (N, C, D, H, W)
if z_plane is None:
z_plane = batch.shape[2] // 2
assert z_plane in range(batch.shape[2])
return lambda x: x[0, :, z_plane]
elif batch.ndim == 4: # (N, C, H, W)
return lambda x: x[0, :]
elif batch.ndim == 2: # (N, C) -> img2scalar
return lambda x: x[None, ] # (1, N, C) -> Image will show N x C probabilities
else:
raise ValueError('Only 4D and 5D tensors are supported.')
def write_to_kzip(trainer: 'Trainer', pred_batch: np.ndarray) -> None:
    """Write a preview prediction to a KNOSSOS k.zip annotation file.

    The file is saved as ``preview_<step>.k.zip`` in the trainer's save
    path, with the movement area set to the preview region.

    Args:
        trainer: Trainer whose ``knossos_preview_config`` dict supplies
            'dataset', 'offset', 'size' and optional 'mag'/'target_mags'.
        pred_batch: Class prediction array; assumed (N, D, H, W) — only the
            first batch element is written.
    """
    # Local import so knossos_utils is only required when this feature is used
    from knossos_utils import KnossosDataset
    ks = trainer.knossos_preview_config
    # 'dataset' may be given as a plain path string or as a sequence whose
    # first element is the path.
    if isinstance(ks['dataset'], str):
        dataset_path = ks['dataset']
    else:
        dataset_path = ks['dataset'][0]
    ds = KnossosDataset(dataset_path)
    seg = pred_batch[0].swapaxes(0, 2)  # (N, D, H, W) -> (W, H, D)
    # Set movement area in k.zip
    area_min = ks['offset']
    area_sz = ks['size']
    anno_str = f"""<?xml version="1.0" encoding="UTF-8"?>
<things>
<parameters>
<MovementArea min.x="{area_min[0]}" min.y="{area_min[1]}" min.z="{area_min[2]}" size.x="{area_sz[0]}" size.y="{area_sz[1]}" size.z="{area_sz[2]}"/>
</parameters>
<comments/>
<branchpoints/>
</things>"""
    kzip_path = f'{trainer.save_path}/preview_{trainer.step}.k.zip'
    logger.info(f'Writing preview inference to {kzip_path}')
    ds.save_to_kzip(
        data=seg,
        data_mag=ks.get('mag', 1),  # source mag of the prediction, default 1
        kzip_path=kzip_path,
        offset=ks['offset'],
        mags=ks.get('target_mags', [1, 2]),  # mags to write into the k.zip
        gen_mergelist=False,
        upsample=False,
        fast_resampling=False,
        annotation_str=anno_str
    )
# TODO: Support regression scenario
def _tb_log_preview(
        trainer: 'Trainer',  # For some reason Trainer can't be imported
        z_plane: Optional[int] = None,
        group: str = 'preview'
) -> None:
    """Preview from constant region of preview batch data.

    Runs inference on the trainer's fixed preview batch and logs output
    channel images, the argmax prediction, an input/prediction overlay and
    (for 5D data, optionally) videos to tensorboard. If a KNOSSOS preview
    is configured, the prediction is additionally written to a k.zip.
    """
    inp_batch = trainer.preview_batch
    out_batch = trainer._preview_inference(
        inp=inp_batch,
        inference_kwargs=trainer.inference_kwargs,
    ).to(torch.float32)
    inp_batch = inp_batch.numpy()
    if trainer.inference_kwargs['apply_softmax']:
        # Convert raw outputs to channel-wise probabilities
        out_batch = F.softmax(out_batch, 1).numpy()
    else:
        out_batch = out_batch.numpy()
    batch2img = _get_batch2img_function(out_batch, z_plane)
    if inp_batch.ndim == 5 and trainer.enable_videos:
        # 5D tensors -> 3D images -> We can make 2D videos out of them
        # See comments in the 5D section in _tb_log_sample_images
        inp_video = squash01(inp_batch)
        # (N, C, T=D, H, W) -> (N, T=D, C, H, W) because of add_video API
        inp_video = np.swapaxes(inp_video, 1, 2)
        trainer.tb.add_video(
            f'{group}_vid/inp', inp_video, global_step=trainer.step
        )
        for c in range(out_batch.shape[1]):
            outc_video = squash01(out_batch[:, c][None])  # Slice C, but keep dimensions intact
            # (N, C=1, T=D, H, W) -> (N, T=D, C=1, H, W)
            outc_video = np.moveaxis(outc_video, 1, 2)
            trainer.tb.add_video(
                f'{group}_vid/out{c}',
                outc_video,
                global_step=trainer.step
            )
    out_slice = batch2img(out_batch)
    pred_slice = out_slice.argmax(0)  # per-pixel class prediction of the 2D slice
    if trainer.knossos_preview_config is not None:
        pred_batch = out_batch.argmax(1)  # per-voxel class prediction of the full batch
        # Optionally map training class IDs back to the original dataset IDs
        remap_ids = trainer.knossos_preview_config.get('remap_ids')
        if remap_ids is not None:
            remap = RemapTargetIDs(remap_ids, reverse=True)
            _, pred_batch = remap(None, pred_batch)
        write_to_kzip(trainer, pred_batch)
    # One grayscale figure per output channel
    for c in range(out_slice.shape[0]):
        trainer.tb.add_figure(
            f'{group}/out{c}',
            plot_image(out_slice[c], cmap='gray'),
            trainer.step
        )
    class_cmap = get_cmap(trainer.max_plot_id)
    trainer.tb.add_figure(
        f'{group}/pred',
        plot_image(pred_slice, vmin=0, vmax=trainer.max_plot_id, cmap=class_cmap),
        trainer.step
    )
    # First input channel of the same slice, used as overlay background.
    # NOTE(review): trainer.preview_batch is sliced here pre-.numpy() — works
    # because torch tensors support the same indexing; confirm intended.
    inp_slice = batch2img(inp_batch)[0]
    trainer.tb.add_figure(
        f'{group}/pred_overlay',
        plot_image(inp_slice, overlay=pred_slice, overlay_alpha=trainer.overlay_alpha, vmin=0, vmax=trainer.max_plot_id, cmap=class_cmap),
        global_step=trainer.step
    )
    # This is only run once per training, because the ground truth for
    # previews is constant (always the same preview inputs/targets)
    if trainer._first_plot:
        inp_slice = batch2img(trainer.preview_batch)[0]
        trainer.tb.add_figure(
            f'{group}/inp',
            plot_image(inp_slice, cmap='gray'),
            global_step=0
        )
        trainer._first_plot = False
def _tb_log_sample_images(
        trainer: 'Trainer',
        images: Dict[str, np.ndarray],
        z_plane: Optional[int] = None,
        group: str = 'sample'
) -> None:
    """Preview from last training/validation sample

    Since the images are chosen randomly from the training/validation set
    they come from random regions in the data set.

    Note: Training images are possibly augmented, so the plots may look
    distorted/weirdly colored.

    Args:
        trainer: Trainer whose tensorboard writer and settings are used.
        images: Dict with at least 'inp' and 'out' arrays; optionally
            'target', 'fname', 'unlabeled' and attention maps ('att*' keys).
        z_plane: Index of the z plane to slice out of 5D tensors.
        group: Tensorboard tag prefix.
    """
    # Always only use the first element of the batch dimension
    inp_batch = images['inp'][:1]
    target_batch = images.get('target')
    if target_batch is not None:
        target_batch = target_batch[:1]
    out_batch = images['out'][:1]
    name = images.get('fname')
    if name is not None:
        name = name[0]  # file name of the first batch element
    continuous_cmap = 'viridis'
    # Convert raw network outputs to probabilities where configured
    if trainer.inference_kwargs['apply_softmax']:
        out_batch = F.softmax(torch.as_tensor(out_batch, dtype=torch.float32), 1).numpy()
    elif trainer.inference_kwargs.get('apply_sigmoid'):
        out_batch = torch.sigmoid(torch.as_tensor(out_batch, dtype=torch.float32)).numpy()
    else:
        out_batch = out_batch.astype(np.float32)
    batch2img_inp = _get_batch2img_function(inp_batch, z_plane)
    inp_slice = batch2img_inp(images['inp'])
    # Optional unlabeled input (e.g. from semi-supervised training)
    uinp_batch = images.get('unlabeled')
    if uinp_batch is not None:
        trainer.tb.add_figure(
            f'{group}/unlabeled_inp',
            plot_image(batch2img_inp(uinp_batch['inp'].cpu().numpy())[0], cmap='gray'),
            global_step=trainer.step
        )
    # TODO: Support one-hot targets
    # TODO: Support multi-label targets
    # TODO: Output vis missing if target_batch is None
    # Check if the network is being trained for classification with class index target tensors
    if target_batch is not None:
        # Classification targets lack the channel dimension that outputs have
        is_classification = target_batch.ndim == out_batch.ndim - 1
        class_cmap = get_cmap(trainer.max_plot_id)
        # If it's not classification, we assume a regression scenario
        is_regression = np.all(target_batch.shape == out_batch.shape)
        # If not exactly one of the scenarios is detected, we can't handle it
        assert is_regression != is_classification
        if is_classification:
            # In classification scenarios, targets have one dim less than network
            # outputs, so if we want to use the same batch2img function for
            # targets, we have to add an empty channel axis to it after the N dimension
            target_batch = target_batch[:, None]
    inp_sh = np.array(inp_batch.shape[2:])
    out_sh = np.array(out_batch.shape[2:])
    # Outputs smaller than inputs (e.g. valid convolutions) are centered by
    # zero-padding so that overlays line up spatially.
    if out_batch.shape[2:] != inp_batch.shape[2:] and not (out_batch.ndim == 2):
        # Zero-pad output and target to match input shape
        # Create a central slice with the size of the output
        lo = (inp_sh - out_sh) // 2
        hi = inp_sh - lo
        slc = tuple([slice(None)] * 2 + [slice(l, h) for l, h in zip(lo, hi)])
        padded_out_batch = np.zeros(
            (inp_batch.shape[0], out_batch.shape[1], *inp_batch.shape[2:]),
            dtype=out_batch.dtype
        )
        padded_out_batch[slc] = out_batch
        out_batch = padded_out_batch
        # Assume that target has the same shape as the output and pad it, too
        if target_batch is not None:
            padded_target_batch = np.zeros((*target_batch.shape[:2], *inp_batch.shape[2:]), dtype=target_batch.dtype)
            padded_target_batch[slc] = target_batch
            target_batch = padded_target_batch
    target_cmap = E3_CMAP  # NOTE(review): appears unused below — looks like dead code; confirm before removing
    batch2img = _get_batch2img_function(out_batch, z_plane)
    if target_batch is not None:
        target_slice = batch2img(target_batch)
    out_slice = batch2img(out_batch)
    if target_batch is not None:
        if is_classification:
            target_slice = target_slice.squeeze(0)  # Squeeze empty axis that was added above
        elif target_slice.shape[0] == 3:  # Assume RGB values
            # RGB images need to be transposed to (H, W, C) layout so matplotlib can handle them
            target_slice = np.moveaxis(target_slice, 0, -1)  # (C, H, W) -> (H, W, C)
            out_slice = np.moveaxis(out_slice, 0, -1)
    if inp_batch.ndim == 5 and trainer.enable_videos:
        # 5D tensors -> 3D images -> We can make 2D videos out of them
        # We re-interpret the D dimension as the temporal dimension T of the video
        # -> (N, T, C, H, W)
        # Inputs and outputs need to be squashed to the (0, 1) intensity range
        # for video rendering, otherwise they will appear as random noise.
        # Since tensorboardX's add_video only supports (N, T, C, H, W) tensors,
        # we have to add a fake C dimension to the (N, D, H, W) target tensors
        # and replace the C dimension of output tensors by empty C dimensions
        # to visualize each channel separately.
        inp_video = squash01(inp_batch)
        # (N, C, T=D, H, W) -> (N, T=D, C, H, W) because of add_video API
        inp_video = np.swapaxes(inp_video, 1, 2)
        trainer.tb.add_video(
            f'{group}_vid/inp', inp_video, global_step=trainer.step
        )
        if target_batch is not None:
            target_video = target_batch
            if target_video.ndim == 4:
                # TODO: This fails with 2D multi-channel targets. Handle these reliably
                target_video = target_video[:, None]
            target_video = np.swapaxes(target_video, 1, 2)
            trainer.tb.add_video(
                f'{group}_vid/target', target_video, global_step=trainer.step
            )
        for c in range(out_batch.shape[1]):
            outc_video = squash01(out_batch[:, c][None])  # Slice C, but keep dimensions intact
            # (N, C=1, T=D, H, W) -> (N, T=D, C=1, H, W)
            outc_video = np.moveaxis(outc_video, 1, 2)
            trainer.tb.add_video(
                f'{group}_vid/out{c}',
                outc_video,
                global_step=trainer.step
            )
        # TODO: Add output and target overlay videos (not straightforward
        # because the 2D overlay code currently uses matplotlib)
    # One grayscale figure per input channel
    for channel in range(inp_slice.shape[0]):
        trainer.tb.add_figure(
            f'{group}/inp{channel}',
            plot_image(inp_slice[channel], cmap='gray', filename=name),
            global_step=trainer.step
        )
    if target_batch is not None:
        _out_channels = trainer.max_plot_id if is_classification else None  # NOTE(review): unused below — confirm
        _cmap = class_cmap if is_classification else continuous_cmap
        if target_slice.ndim == 2:
            trainer.tb.add_figure(
                f'{group}/target',
                plot_image(
                    target_slice, vmin=0, vmax=trainer.max_plot_id, filename=name, cmap=_cmap
                    # vmin=0., vmax=1.
                ),
                global_step=trainer.step
            )
        elif target_slice.ndim == 3:
            # Multi-channel (e.g. RGB regression) target: one figure per channel
            for c in range(target_slice.shape[0]):
                trainer.tb.add_figure(
                    f'{group}/target{c}',
                    plot_image(
                        target_slice[c], vmin=0, vmax=trainer.max_plot_id, filename=name, cmap=_cmap
                        # vmin=0., vmax=1.
                    ),
                    global_step=trainer.step
                )
    # Plot any attention maps that were passed along in the images dict
    for key, img in images.items():
        if key.startswith('att'):
            trainer.tb.add_figure(
                f'{group}/{key}',
                plot_image(img, cmap='viridis'),
                global_step=trainer.step
            )
    # Plot each output channel c individually as "out{c}"
    for c in range(out_slice.shape[0]):
        trainer.tb.add_figure(
            f'{group}/out{c}',
            plot_image(
                out_slice[c], cmap=continuous_cmap, filename=name,
                # vmin=0., vmax=1.
            ),
            global_step=trainer.step
        )
    # Only make pred and overlay plots in classification scenarios
    if target_batch is not None:
        if is_classification:
            pred_slice = out_slice.argmax(0)  # per-pixel class prediction
            trainer.tb.add_figure(
                f'{group}/pred_slice',
                plot_image(pred_slice, vmin=0, vmax=trainer.max_plot_id, cmap=class_cmap, filename=name),
                global_step=trainer.step
            )
            if target_batch is not None and not target_batch.ndim == 2:  # TODO: Make this condition more reliable and document it
                # Overlay target and prediction on each input channel
                for c in range(inp_slice.shape[0]):
                    trainer.tb.add_figure(
                        f'{group}/target_overlay{c}',
                        plot_image(inp_slice[c], overlay=target_slice, overlay_alpha=trainer.overlay_alpha, vmin=0, vmax=trainer.max_plot_id, cmap=class_cmap, filename=name),
                        global_step=trainer.step
                    )
                    trainer.tb.add_figure(
                        f'{group}/pred_overlay',
                        plot_image(inp_slice[c], overlay=pred_slice, overlay_alpha=trainer.overlay_alpha, vmin=0, vmax=trainer.max_plot_id, cmap=class_cmap, filename=name),
                        global_step=trainer.step
                    )
def _tb_log_sample_images_all_img(
        trainer: 'Trainer',
        images: Dict[str, np.ndarray],
        z_plane: Optional[int] = None,
        group: str = 'sample'
) -> None:
    """Tensorboard plotting handler that plots all arrays in the ``images``
    dict as 2D grayscale images. Multi-channel images are split along the
    C dimension and plotted separately.
    """
    # TODO: Clean up/remove the messy name handling. Figure out how to pass non-image data cleanly.
    name = images.pop('fname', [None])[0]
    for key, arr in images.items():
        # Restrict to the first batch element and slice a (C, H, W) image
        first = arr[:1]
        img2d = _get_batch2img_function(first, z_plane)(first)
        num_channels = img2d.shape[0]
        if num_channels == 1:
            # Single channel: plot directly under the plain key
            trainer.tb.add_figure(
                f'{group}/{key}',
                plot_image(img2d[0], cmap='gray', filename=name),
                global_step=trainer.step
            )
            continue
        # Multiple channels: one figure per channel, suffixed with its index
        for c in range(num_channels):
            trainer.tb.add_figure(
                f'{group}/{key}{c}',
                plot_image(img2d[c], cmap='gray', filename=name),
                global_step=trainer.step
            )