Source code for elektronn3.models.unet3d_lite

from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F


class PoolingError(Exception):
    pass
class UNet3dLite(nn.Module):
    """(WIP) Re-implementation of the unet3d_lite model from ELEKTRONN2

    See https://github.com/ELEKTRONN/ELEKTRONN2/blob/master/examples/unet3d_lite.py

    Pay attention to shapes: Only spatial input shape (22, 140, 140) is supported.
    fov=[12, 88, 88], offsets=[6, 44, 44], strides=[1 1 1], spatial shape=[10, 52, 52]

    This model is directly compatible with torch.jit.script.
    """
    def __init__(self):
        super().__init__()
        # self.down = nn.MaxPool3d((1, 2, 2))
        self.conv0 = nn.Conv3d(1, 32, (1, 3, 3))
        self.conv1 = nn.Conv3d(32, 32, (1, 3, 3))
        self.conv2 = nn.Conv3d(32, 64, (1, 3, 3))
        self.conv3 = nn.Conv3d(64, 64, (1, 3, 3))
        self.conv4 = nn.Conv3d(64, 128, (1, 3, 3))
        self.conv5 = nn.Conv3d(128, 128, (1, 3, 3))
        self.conv6 = nn.Conv3d(128, 256, (3, 3, 3))
        self.conv7 = nn.Conv3d(256, 128, (3, 3, 3))
        self.upconv0 = nn.ConvTranspose3d(128, 512, (1, 2, 2), (1, 2, 2))
        self.mconv0 = nn.Conv3d(640, 256, (1, 3, 3))
        self.mconv1 = nn.Conv3d(256, 64, (1, 3, 3))
        self.upconv1 = nn.ConvTranspose3d(64, 256, (1, 2, 2), (1, 2, 2))
        self.mconv2 = nn.Conv3d(320, 128, (3, 3, 3))
        self.mconv3 = nn.Conv3d(128, 32, (3, 3, 3))
        self.upconv2 = nn.ConvTranspose3d(32, 128, (1, 2, 2), (1, 2, 2))
        self.mconv4 = nn.Conv3d(160, 64, (3, 3, 3))
        self.mconv5 = nn.Conv3d(64, 64, (3, 3, 3))
        self.conv_final = nn.Conv3d(64, 2, 1)
    @staticmethod
    def autocrop(
            from_down: torch.Tensor, from_up: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Center-crop the skip-connection tensor (from_down) to the spatial
        # shape of the upsampled tensor (from_up) so they can be concatenated.
        ds = from_down.shape[2:]
        us = from_up.shape[2:]
        from_down = from_down[
            :,
            :,
            ((ds[0] - us[0]) // 2):((ds[0] + us[0]) // 2),
            ((ds[1] - us[1]) // 2):((ds[1] + us[1]) // 2),
            ((ds[2] - us[2]) // 2):((ds[2] + us[2]) // 2),
        ]
        return from_down, from_up
    @staticmethod
    def down(x, ks=(1, 2, 2)):
        # Before pooling, we manually check if the tensor is divisible by the pooling kernel
        # size, because PyTorch doesn't throw an error if it's not divisible, but calculates
        # the output shape by floor division instead. While this may make sense for other
        # architectures, in U-Net this would lead to incorrect output shapes after upsampling.
        sh = x.shape[2:]
        if any([s % k != 0 for s, k in zip(sh, ks)]):
            raise PoolingError(
                f'Can\'t pool {sh} input by a {ks} kernel. Please adjust the input shape.'
            )
        return F.max_pool3d(x, ks)
    def forward(self, inp):
        conv0 = F.relu(self.conv0(inp))
        conv1 = F.relu(self.conv1(conv0))
        down0 = self.down(conv1)
        conv2 = F.relu(self.conv2(down0))
        conv3 = F.relu(self.conv3(conv2))
        down1 = self.down(conv3)
        conv4 = F.relu(self.conv4(down1))
        conv5 = F.relu(self.conv5(conv4))
        down2 = self.down(conv5)
        conv6 = F.relu(self.conv6(down2))
        conv7 = F.relu(self.conv7(conv6))
        upconv0 = F.relu(self.upconv0(conv7))
        d0, u0 = self.autocrop(conv5, upconv0)
        mrg0 = torch.cat((d0, u0), 1)
        mconv0 = F.relu(self.mconv0(mrg0))
        mconv1 = F.relu(self.mconv1(mconv0))
        upconv1 = F.relu(self.upconv1(mconv1))
        d1, u1 = self.autocrop(conv3, upconv1)
        mrg1 = torch.cat((d1, u1), 1)
        mconv2 = F.relu(self.mconv2(mrg1))
        mconv3 = F.relu(self.mconv3(mconv2))
        upconv2 = F.relu(self.upconv2(mconv3))
        d2, u2 = self.autocrop(conv1, upconv2)
        mrg2 = torch.cat((d2, u2), 1)
        mconv4 = F.relu(self.mconv4(mrg2))
        mconv5 = F.relu(self.mconv5(mconv4))
        out = self.conv_final(mconv5)
        return out
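# A minimal sketch (hypothetical helper, not part of the original module):
# traces the lateral (H/W) axis through the valid convolutions, poolings and
# upsamplings above, showing why only lateral input size 140 yields the
# documented output size 52.
def _trace_lateral(s: int) -> int:
    for _ in range(3):    # encoder: two valid convs (lateral k=3), then a (1, 2, 2) pool
        s = (s - 2 - 2) // 2
    s = s - 2 - 2         # bottleneck convs conv6, conv7 (valid, lateral k=3)
    for _ in range(3):    # decoder: 2x lateral upsampling, then two valid convs
        s = s * 2 - 2 - 2
    return s              # _trace_lateral(140) == 52, matching the docstring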
if __name__ == '__main__':
    m = UNet3dLite()
    x = torch.randn(1, 1, 22, 140, 140)
    y = m(x)
    print(y.shape)
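    # A minimal illustration (assumed example, not from the original file):
    # autocrop is a pure center-crop of the skip tensor and can also be used
    # standalone as a staticmethod.
    skip, up = torch.zeros(1, 1, 4, 8, 8), torch.zeros(1, 1, 2, 4, 4)
    cropped, up = UNet3dLite.autocrop(skip, up)
    assert cropped.shape == up.shape == (1, 1, 2, 4, 4)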
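    # Hedged sketches of two claims made above (assumed usage, not from the
    # original file): the docstring states direct torch.jit.script
    # compatibility, and down() raises PoolingError for non-divisible shapes.
    scripted = torch.jit.script(m)
    print(scripted(x).shape)  # same (1, 2, 10, 52, 52) as the eager model
    try:
        m(torch.randn(1, 1, 22, 141, 141))  # 141 - 2 - 2 = 137 is not divisible by 2
    except PoolingError as e:
        print(e)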