Source code for elektronn3.models.fcn

# ELEKTRONN3 - Neural Network Toolkit
#
# Copyright (c) 2017 - now
# Max Planck Institute of Neurobiology, Munich, Germany

"""
adopted from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/fcn.py
LICENSE https://github.com/meetshah1995/pytorch-semseg/blob/master/LICENSE
#In every layer few steps have been commented because of Memory constraints (please uncomment them acc to the resources)
"""


import torch.nn as nn
import torch.nn.functional as F


class fcn32s(nn.Module):
    def __init__(self, n_classes=2, learned_billinear=False, red_fac=16):
        super(fcn32s, self).__init__()
        self.learned_billinear = learned_billinear
        self.n_classes = n_classes

        self.conv_block1 = nn.Sequential(
            nn.Conv3d(1, 64 // red_fac, 3, padding=100),
            nn.ReLU(inplace=True),
            # nn.Conv3d(64 // red_fac, 64 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv3d(64 // red_fac, 128 // red_fac, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(128 // red_fac, 128 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv3d(128 // red_fac, 256 // red_fac, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(256 // red_fac, 256 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(256 // red_fac, 256 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block4 = nn.Sequential(
            nn.Conv3d(256 // red_fac, 512 // red_fac, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(512 // red_fac, 512 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(512 // red_fac, 512 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block5 = nn.Sequential(
            nn.Conv3d(512 // red_fac, 512 // red_fac, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(512 // red_fac, 512 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(512 // red_fac, 512 // red_fac, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.classifier = nn.Sequential(
            nn.Conv3d(512 // red_fac, 4096 // red_fac, 7),
            nn.ReLU(inplace=True),
            nn.Dropout3d(),
            # nn.Conv3d(4096 // red_fac, 4096 // red_fac, 1),
            # nn.ReLU(inplace=True),
            # nn.Dropout3d(),
            nn.Conv3d(4096 // red_fac, self.n_classes, 1),
        )

        # TODO: Add support for learned upsampling
        if self.learned_billinear:
            raise NotImplementedError
            # upscore = nn.ConvTranspose3d(self.n_classes, self.n_classes,
            #                              64, stride=32, bias=False)
            # upscore.scale_factor = None
    def forward(self, x):
        out = self.conv_block1(x)
        out = self.conv_block2(out)
        out = self.conv_block3(out)
        out = self.conv_block4(out)
        out = self.conv_block5(out)
        score = self.classifier(out)
        # F.upsample is deprecated; F.interpolate with align_corners=False
        # matches its current default behavior.
        out = F.interpolate(score, x.size()[2:], mode='trilinear',
                            align_corners=False)
        return out
    def init_vgg16_params(self, vgg16, copy_fc8=True):
        # NOTE: The indices below assume the full VGG-16 layout, i.e. they
        # only work if the commented-out layers above are re-enabled (the
        # reduced classifier has 4 entries, not 7). Also, vgg16 must be a 3D
        # VGG variant: torchvision's VGG-16 contains Conv2d layers, which the
        # isinstance check below silently skips.
        blocks = [self.conv_block1,
                  self.conv_block2,
                  self.conv_block3,
                  self.conv_block4,
                  self.conv_block5]
        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
        features = list(vgg16.features.children())
        for idx, conv_block in enumerate(blocks):
            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
                if isinstance(l1, nn.Conv3d) and isinstance(l2, nn.Conv3d):
                    assert l1.weight.size() == l2.weight.size()
                    assert l1.bias.size() == l2.bias.size()
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data
        for i1, i2 in zip([0, 3], [0, 3]):
            l1 = vgg16.classifier[i1]
            l2 = self.classifier[i2]
            l2.weight.data = l1.weight.data.view(l2.weight.size())
            l2.bias.data = l1.bias.data.view(l2.bias.size())
        n_class = self.classifier[6].weight.size()[0]
        if copy_fc8:
            l1 = vgg16.classifier[6]
            l2 = self.classifier[6]
            l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
            l2.bias.data = l1.bias.data[:n_class]
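# Note on fcn32s (added for clarity, not part of the upstream file): with the
# default red_fac=16 the encoder channel widths are 4-8-16-32-32 and the
# classifier maps 32 -> 256 -> n_classes channels. padding=100 in conv_block1
# is the usual FCN trick that keeps the 7**3 classifier convolution valid even
# for small inputs; the resulting score map is simply resized back to the
# input's spatial shape by trilinear interpolation in forward().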
class fcn16s(nn.Module):
    def __init__(self, n_classes=2, learned_billinear=False):
        super(fcn16s, self).__init__()
        self.learned_billinear = learned_billinear
        self.n_classes = n_classes

        self.conv_block1 = nn.Sequential(
            nn.Conv3d(1, 64, 3, padding=100),
            nn.ReLU(inplace=True),
            # nn.Conv3d(64, 64, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv3d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(128, 128, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv3d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(256, 256, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(256, 256, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block4 = nn.Sequential(
            nn.Conv3d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block5 = nn.Sequential(
            nn.Conv3d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.classifier = nn.Sequential(
            nn.Conv3d(512, 4096, 7),
            nn.ReLU(inplace=True),
            nn.Dropout3d(),
            # nn.Conv3d(4096, 4096, 1),
            # nn.ReLU(inplace=True),
            # nn.Dropout3d(),
            nn.Conv3d(4096, self.n_classes, 1),
        )

        self.score_pool4 = nn.Conv3d(512, self.n_classes, 1)

        # TODO: Add support for learned upsampling
        if self.learned_billinear:
            raise NotImplementedError
            # upscore = nn.ConvTranspose3d(self.n_classes, self.n_classes,
            #                              64, stride=32, bias=False)
            # upscore.scale_factor = None
    def forward(self, x):
        conv1 = self.conv_block1(x)
        conv2 = self.conv_block2(conv1)
        conv3 = self.conv_block3(conv2)
        conv4 = self.conv_block4(conv3)
        conv5 = self.conv_block5(conv4)
        score = self.classifier(conv5)
        score_pool4 = self.score_pool4(conv4)
        # FCN-16s skip: fuse the coarse classifier scores with the finer
        # pool4 scores before upsampling to input resolution.
        score = F.interpolate(score, score_pool4.size()[2:], mode='trilinear',
                              align_corners=False)
        score += score_pool4
        out = F.interpolate(score, x.size()[2:], mode='trilinear',
                            align_corners=False)
        return out
    def init_vgg16_params(self, vgg16, copy_fc8=True):
        # See the note in fcn32s.init_vgg16_params: the indices below assume
        # the full VGG-16 layout with the commented-out layers re-enabled.
        blocks = [self.conv_block1,
                  self.conv_block2,
                  self.conv_block3,
                  self.conv_block4,
                  self.conv_block5]
        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
        features = list(vgg16.features.children())
        for idx, conv_block in enumerate(blocks):
            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
                if isinstance(l1, nn.Conv3d) and isinstance(l2, nn.Conv3d):
                    assert l1.weight.size() == l2.weight.size()
                    assert l1.bias.size() == l2.bias.size()
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data
        for i1, i2 in zip([0, 3], [0, 3]):
            l1 = vgg16.classifier[i1]
            l2 = self.classifier[i2]
            l2.weight.data = l1.weight.data.view(l2.weight.size())
            l2.bias.data = l1.bias.data.view(l2.bias.size())
        n_class = self.classifier[6].weight.size()[0]
        if copy_fc8:
            l1 = vgg16.classifier[6]
            l2 = self.classifier[6]
            l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
            l2.bias.data = l1.bias.data[:n_class]
# FCN 8s
class fcn8s(nn.Module):
    def __init__(self, n_classes=2, learned_billinear=False):
        super(fcn8s, self).__init__()
        self.learned_billinear = learned_billinear
        self.n_classes = n_classes

        self.conv_block1 = nn.Sequential(
            nn.Conv3d(1, 64, 3, padding=100),
            nn.ReLU(inplace=True),
            # nn.Conv3d(64, 64, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv3d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(128, 128, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv3d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(256, 256, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(256, 256, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block4 = nn.Sequential(
            nn.Conv3d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.conv_block5 = nn.Sequential(
            nn.Conv3d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            # nn.Conv3d(512, 512, 3, padding=1),
            # nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
        )

        self.classifier = nn.Sequential(
            nn.Conv3d(512, 4096, 7),
            nn.ReLU(inplace=True),
            nn.Dropout3d(),
            # nn.Conv3d(4096, 4096, 1),
            # nn.ReLU(inplace=True),
            # nn.Dropout3d(),
            nn.Conv3d(4096, self.n_classes, 1),
        )

        self.score_pool4 = nn.Conv3d(512, self.n_classes, 1)
        self.score_pool3 = nn.Conv3d(256, self.n_classes, 1)

        # TODO: Add support for learned upsampling
        if self.learned_billinear:
            raise NotImplementedError
            # upscore = nn.ConvTranspose3d(self.n_classes, self.n_classes,
            #                              64, stride=32, bias=False)
            # upscore.scale_factor = None
    def forward(self, x):
        conv1 = self.conv_block1(x)
        conv2 = self.conv_block2(conv1)
        conv3 = self.conv_block3(conv2)
        conv4 = self.conv_block4(conv3)
        conv5 = self.conv_block5(conv4)
        score = self.classifier(conv5)
        score_pool4 = self.score_pool4(conv4)
        score_pool3 = self.score_pool3(conv3)
        # Two-stage skip fusion (FCN-8s): add pool4 scores, then pool3 scores,
        # upsampling the running score map to the finer resolution each time.
        score = F.interpolate(score, score_pool4.size()[2:], mode='trilinear',
                              align_corners=False)
        score += score_pool4
        score = F.interpolate(score, score_pool3.size()[2:], mode='trilinear',
                              align_corners=False)
        score += score_pool3
        out = F.interpolate(score, x.size()[2:], mode='trilinear',
                            align_corners=False)
        return out
    def init_vgg16_params(self, vgg16, copy_fc8=True):
        # See the note in fcn32s.init_vgg16_params: the indices below assume
        # the full VGG-16 layout with the commented-out layers re-enabled.
        blocks = [self.conv_block1,
                  self.conv_block2,
                  self.conv_block3,
                  self.conv_block4,
                  self.conv_block5]
        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
        features = list(vgg16.features.children())
        for idx, conv_block in enumerate(blocks):
            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
                if isinstance(l1, nn.Conv3d) and isinstance(l2, nn.Conv3d):
                    assert l1.weight.size() == l2.weight.size()
                    assert l1.bias.size() == l2.bias.size()
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data
        for i1, i2 in zip([0, 3], [0, 3]):
            l1 = vgg16.classifier[i1]
            l2 = self.classifier[i2]
            l2.weight.data = l1.weight.data.view(l2.weight.size())
            l2.bias.data = l1.bias.data.view(l2.bias.size())
        n_class = self.classifier[6].weight.size()[0]
        if copy_fc8:
            l1 = vgg16.classifier[6]
            l2 = self.classifier[6]
            l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
            l2.bias.data = l1.bias.data[:n_class]
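
# A minimal smoke test (an illustrative addition, not part of the upstream
# file). It only exercises fcn32s: with the default red_fac=16 its channel
# widths are reduced 16-fold, whereas fcn16s/fcn8s use the full VGG widths and
# need far more memory because padding=100 inflates the first feature maps.
if __name__ == '__main__':
    import torch

    model = fcn32s(n_classes=2, red_fac=16)
    model.eval()  # disable Dropout3d for a deterministic check
    x = torch.randn(1, 1, 16, 32, 32)  # NCDHW: one single-channel volume
    with torch.no_grad():
        out = model(x)
    # Output scores have one channel per class at the input resolution.
    assert out.shape == (1, 2, 16, 32, 32)
    print(out.shape)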