mfrashad's picture
Init code
97069e1
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from . import resnet, resnext
try:
from lib.nn import SynchronizedBatchNorm2d
except ImportError:
from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
class SegmentationModuleBase(nn.Module):
    """Base ``nn.Module`` holding stateless accuracy / loss helpers.

    Every helper is a ``staticmethod`` and only uses the tensors passed in.
    """

    def __init__(self):
        super(SegmentationModuleBase, self).__init__()

    @staticmethod
    def pixel_acc(pred, label, ignore_index=-1):
        """Return the fraction of counted pixels predicted correctly.

        The prediction is the argmax of ``pred`` over dim 1. Pixels whose
        ``label`` equals ``ignore_index`` are not counted. The result is a
        0-dim float tensor; ``1e-10`` in the denominator avoids 0/0 when
        every pixel is ignored.
        """
        _, predicted = torch.max(pred, dim=1)
        counted = (label != ignore_index).long()
        n_correct = (counted * (predicted == label).long()).sum()
        n_counted = counted.sum()
        return n_correct.float() / (n_counted.float() + 1e-10)

    @staticmethod
    def part_pixel_acc(pred_part, gt_seg_part, gt_seg_object, object_label, valid):
        """Count correct part predictions inside one object class.

        Only pixels where ``gt_seg_object == object_label`` are considered.
        Per-sample counts are weighted by ``valid`` before summing over the
        batch.

        Returns:
            ``(correct, total)``: two 0-dim tensors.
        """
        in_object = gt_seg_object == object_label
        _, predicted = torch.max(pred_part, dim=1)
        hits = in_object * (predicted == gt_seg_part)
        hits_per_sample = hits.view(hits.size(0), -1).sum(dim=1)
        pixels_per_sample = in_object.view(in_object.size(0), -1).sum(dim=1)
        return (hits_per_sample * valid).sum(), (pixels_per_sample * valid).sum()

    @staticmethod
    def part_loss(pred_part, gt_seg_part, gt_seg_object, object_label, valid):
        """Masked NLL loss of one object's part head.

        The loss is summed over pixels of ``object_label`` in samples weighted
        by ``valid``. It is then divided by the number of such pixels, with
        the count clamped to at least 1.
        """
        in_object = gt_seg_object == object_label
        # Targets outside the object are zeroed (presumably so every target
        # index is in range for nll_loss); those pixels are masked out below.
        target = gt_seg_part * in_object.long()
        per_pixel = F.nll_loss(pred_part, target, reduction='none') * in_object.float()
        per_sample = per_pixel.view(per_pixel.size(0), -1).sum(dim=1)
        n_pixels = in_object.view(in_object.shape[0], -1).sum(dim=1)
        denominator = torch.clamp((n_pixels * valid).sum(), min=1).float()
        return (per_sample * valid.float()).sum() / denominator
class SegmentationModule(SegmentationModuleBase):
    """Encoder/decoder wrapper for multi-task (object/part/scene/material) parsing.

    Training (``seg_size is None``): runs the heads selected by
    ``feed_dict['source_idx']`` and returns ``{'metric': ..., 'loss': ...}``.
    Inference (``seg_size`` given): enables every head and returns the
    decoder's prediction dict.
    """

    def __init__(self, net_enc, net_dec, labeldata, loss_scale=None):
        """
        Args:
            net_enc: encoder, called as ``net_enc(img, return_feature_maps=True)``.
            net_dec: decoder, called with ``output_switch`` (plus ``seg_size`` at
                inference). ``object_part`` / ``object_with_part`` are assigned
                onto it here.
            labeldata: dict with 'object' and 'part' (ordered name lists) and
                'object_part' (object name -> list of its part names).
            loss_scale: per-task loss weights; defaults to
                object 1, part 0.5, scene 0.25, material 1.
        """
        super(SegmentationModule, self).__init__()
        self.encoder = net_enc
        self.decoder = net_dec
        self.crit_dict = nn.ModuleDict()
        if loss_scale is None:
            self.loss_scale = {"object": 1, "part": 0.5, "scene": 0.25, "material": 1}
        else:
            self.loss_scale = loss_scale
        # criterion
        self.crit_dict["object"] = nn.NLLLoss(ignore_index=0)  # ignore background 0
        self.crit_dict["material"] = nn.NLLLoss(ignore_index=0)  # ignore background 0
        self.crit_dict["scene"] = nn.NLLLoss(ignore_index=-1)  # ignore unlabelled -1
        # Label data - read from json
        self.labeldata = labeldata
        object_to_num = {k: v for v, k in enumerate(labeldata['object'])}
        part_to_num = {k: v for v, k in enumerate(labeldata['part'])}
        # object index -> list of that object's part indices
        self.object_part = {object_to_num[k]: [part_to_num[p] for p in v]
                            for k, v in labeldata['object_part'].items()}
        self.object_with_part = sorted(self.object_part.keys())
        # The decoder needs the same grouping to split its part channels.
        self.decoder.object_part = self.object_part
        self.decoder.object_with_part = self.object_with_part

    def forward(self, feed_dict, *, seg_size=None):
        """Run the model.

        Args:
            feed_dict: batch dict. It always has 'img'. For training it also
                has 'source_idx' and the annotations used by the active heads.
            seg_size: output size for inference; ``None`` selects training.

        Returns:
            Training: ``{'metric': metric_dict, 'loss': loss_dict}``.
            Inference: the decoder's output dict.

        Raises:
            ValueError: if ``feed_dict['source_idx']`` is not 0 or 1 in training.
        """
        if seg_size is not None:  # inference: run every head
            output_switch = {"object": True, "part": True, "scene": True, "material": True}
            return self.decoder(self.encoder(feed_dict['img'], return_feature_maps=True),
                                output_switch=output_switch, seg_size=seg_size)

        # training: the data source decides which heads are supervised
        source_idx = feed_dict['source_idx']
        if source_idx == 0:
            output_switch = {"object": True, "part": True, "scene": True, "material": False}
        elif source_idx == 1:
            output_switch = {"object": False, "part": False, "scene": False, "material": True}
        else:
            raise ValueError(
                'Unknown source_idx {!r}; expected 0 or 1'.format(source_idx))
        pred = self.decoder(
            self.encoder(feed_dict['img'], return_feature_maps=True),
            output_switch=output_switch
        )
        loss_dict = self._compute_losses(pred, feed_dict)
        metric_dict = self._compute_metrics(pred, feed_dict)
        return {'metric': metric_dict, 'loss': loss_dict}

    def _compute_losses(self, pred, feed_dict):
        """Per-task losses for the heads that produced output, plus a weighted 'total'."""
        loss_dict = {}
        if pred['object'] is not None:
            loss_dict['object'] = self.crit_dict['object'](pred['object'], feed_dict['seg_object'])
        if pred['part'] is not None:
            part_loss = 0
            for idx_part, object_label in enumerate(self.object_with_part):
                part_loss += self.part_loss(
                    pred['part'][idx_part], feed_dict['seg_part'],
                    feed_dict['seg_object'], object_label, feed_dict['valid_part'][:, idx_part])
            loss_dict['part'] = part_loss
        if pred['scene'] is not None:
            loss_dict['scene'] = self.crit_dict['scene'](pred['scene'], feed_dict['scene_label'])
        if pred['material'] is not None:
            loss_dict['material'] = self.crit_dict['material'](pred['material'], feed_dict['seg_material'])
        loss_dict['total'] = sum(loss_dict[k] * self.loss_scale[k] for k in loss_dict)
        return loss_dict

    def _compute_metrics(self, pred, feed_dict):
        """Pixel accuracies for the heads that produced output."""
        metric_dict = {}
        if pred['object'] is not None:
            metric_dict['object'] = self.pixel_acc(
                pred['object'], feed_dict['seg_object'], ignore_index=0)
        if pred['material'] is not None:
            metric_dict['material'] = self.pixel_acc(
                pred['material'], feed_dict['seg_material'], ignore_index=0)
        if pred['part'] is not None:
            acc_sum, pixel_sum = 0, 0
            for idx_part, object_label in enumerate(self.object_with_part):
                acc, pixel = self.part_pixel_acc(
                    pred['part'][idx_part], feed_dict['seg_part'], feed_dict['seg_object'],
                    object_label, feed_dict['valid_part'][:, idx_part])
                acc_sum += acc
                pixel_sum += pixel
            # as_tensor keeps this working when object_with_part is empty
            # (the sums are then plain ints, which have no .float()).
            metric_dict['part'] = (torch.as_tensor(acc_sum).float()
                                   / (torch.as_tensor(pixel_sum).float() + 1e-10))
        if pred['scene'] is not None:
            metric_dict['scene'] = self.pixel_acc(
                pred['scene'], feed_dict['scene_label'], ignore_index=-1)
        return metric_dict
def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """Build a 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved. The conv has a bias
    only when ``has_bias`` is True.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=has_bias,
    )
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """Return ``conv3x3 -> SynchronizedBatchNorm2d -> in-place ReLU`` as an ``nn.Sequential``.

    Submodule indices 0/1/2 appear in state_dict keys, so keep this order.
    """
    layers = [
        conv3x3(in_planes, out_planes, stride),
        SynchronizedBatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class ModelBuilder:
    """Factory for the encoder and decoder networks."""

    def __init__(self):
        pass

    @staticmethod
    def weights_init(m):
        """Initialise one module in place; intended for ``nn.Module.apply``.

        Modules whose class name contains 'Conv' get Kaiming-normal weights
        (ReLU gain). Modules whose class name contains 'BatchNorm' get
        weight 1 and bias 1e-4. Parameters that are absent or None, e.g.
        BatchNorm with ``affine=False``, are skipped instead of crashing.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            if getattr(m, 'weight', None) is not None:
                nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
        elif 'BatchNorm' in classname:
            if getattr(m, 'weight', None) is not None:
                m.weight.data.fill_(1.)
            if getattr(m, 'bias', None) is not None:
                m.bias.data.fill_(1e-4)

    def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
        """Build the backbone encoder.

        Args:
            arch: one of 'resnet50', 'resnet101', 'resnext101'. The
                'resnet34*' variants raise NotImplementedError.
                NOTE(review): the default 'resnet50_dilated8' is not handled
                and raises ValueError, so callers must pass ``arch`` explicitly.
            fc_dim: unused here; kept for signature compatibility.
            weights: checkpoint path. If empty or None, the backbone is built
                with ``pretrained=True``. Otherwise it is built with
                ``pretrained=False``, and the checkpoint is loaded onto CPU
                with ``strict=False``.

        Returns:
            A ``Resnet`` wrapping the chosen backbone.

        Raises:
            NotImplementedError: for the 'resnet34*' architectures.
            ValueError: for any other unknown ``arch``.
        """
        pretrained = not weights
        if arch in ('resnet34', 'resnet34_dilated8', 'resnet34_dilated16'):
            raise NotImplementedError(
                'Architecture {!r} is not supported'.format(arch))
        if arch in ('resnet50', 'resnet101'):
            orig_net = resnet.__dict__[arch](pretrained=pretrained)
        elif arch == 'resnext101':
            orig_net = resnext.__dict__['resnext101'](pretrained=pretrained)
        else:
            raise ValueError('Architecture undefined! ({!r})'.format(arch))
        # Resnet also wraps ResNeXt (it only reads the stem/stage attributes).
        net_encoder = Resnet(orig_net)
        # The encoder is not passed through weights_init: it keeps its
        # pretrained or checkpoint weights.
        if weights:
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_encoder

    def build_decoder(self, nr_classes,
                      arch='ppm_bilinear_deepsup', fc_dim=512,
                      weights='', use_softmax=False):
        """Build the UPerNet decoder.

        Args:
            nr_classes: per-task class counts forwarded to ``UPerNet``.
            arch: 'upernet_lite' (fpn_dim 256) or 'upernet' (fpn_dim 512).
                NOTE(review): the default 'ppm_bilinear_deepsup' is not handled
                and raises ValueError, so callers must pass ``arch`` explicitly.
            fc_dim: channels of the deepest encoder feature map.
            weights: optional checkpoint path, loaded onto CPU with ``strict=False``.
            use_softmax: forwarded to ``UPerNet``.

        Raises:
            ValueError: for an unknown ``arch``.
        """
        fpn_dims = {'upernet_lite': 256, 'upernet': 512}
        if arch not in fpn_dims:
            raise ValueError('Architecture undefined! ({!r})'.format(arch))
        net_decoder = UPerNet(
            nr_classes=nr_classes,
            fc_dim=fc_dim,
            use_softmax=use_softmax,
            fpn_dim=fpn_dims[arch])
        net_decoder.apply(self.weights_init)
        if weights:
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_decoder
class Resnet(nn.Module):
    """Backbone wrapper around a ResNet-like model with a three-conv stem.

    It keeps the stem (conv1-3 with BN/ReLU), the max-pool and the four
    residual stages. The original model's avgpool and fc are not copied.
    """

    # Attribute names borrowed from the wrapped model; the same names keep
    # state_dict keys aligned with the source model.
    _BORROWED = ('conv1', 'bn1', 'relu1',
                 'conv2', 'bn2', 'relu2',
                 'conv3', 'bn3', 'relu3',
                 'maxpool',
                 'layer1', 'layer2', 'layer3', 'layer4')

    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()
        for name in self._BORROWED:
            setattr(self, name, getattr(orig_resnet, name))

    def forward(self, x, return_feature_maps=False):
        """Run the backbone.

        Returns the outputs of layer1..layer4 when ``return_feature_maps`` is
        True, otherwise a one-element list holding the layer4 output.
        """
        stem = ((self.conv1, self.bn1, self.relu1),
                (self.conv2, self.bn2, self.relu2),
                (self.conv3, self.bn3, self.relu3))
        for conv, norm, act in stem:
            x = act(norm(conv(x)))
        x = self.maxpool(x)

        features = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)
        return features if return_feature_maps else [x]
# upernet
class UPerNet(nn.Module):
    """UPerNet-style decoder with four task heads (scene, object, part, material).

    It runs a PPM on the deepest feature map, then an FPN top-down path.

    ``forward`` reads ``self.object_part`` (object index -> list of part
    indices) and ``self.object_with_part`` (ordered object indices) when the
    part head is enabled. They are not set here; the owning module must
    assign them before ``forward`` is called.
    """

    def __init__(self, nr_classes, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):
        """
        Args:
            nr_classes: dict with 'scene', 'object', 'part' and 'material'
                class counts (background included).
            fc_dim: channels of the deepest encoder feature map (PPM input).
            use_softmax: True -> ``forward`` returns upsampled probabilities;
                False -> log-probabilities at feature resolution.
            pool_scales: PrRoI pooling grid sizes of the PPM branches.
            fpn_inplanes: channels of each encoder feature map, shallow to deep.
            fpn_dim: output channels of the PPM, the FPN levels and the
                head inputs.
        """
        # Lazy import so that compilation isn't needed if not being used.
        from .prroi_pool import PrRoIPool2D
        super(UPerNet, self).__init__()
        self.use_softmax = use_softmax

        # NOTE: the module structure and registration order below determine
        # state_dict keys and parameter order; keep them unchanged so saved
        # checkpoints still load.

        # PPM Module: one (PrRoI pool -> 1x1 conv/BN/ReLU) branch per scale.
        self.ppm_pooling = []
        self.ppm_conv = []
        for scale in pool_scales:
            # we use the feature map size instead of input image size, so down_scale = 1.0
            self.ppm_pooling.append(PrRoIPool2D(scale, scale, 1.))
            self.ppm_conv.append(nn.Sequential(
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
        self.ppm_conv = nn.ModuleList(self.ppm_conv)
        self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales) * 512, fpn_dim, 1)

        # FPN Module: lateral 1x1 convs and 3x3 output convs for every level
        # except the top one, which comes from the PPM.
        self.fpn_in = []
        for fpn_inplane in fpn_inplanes[:-1]:  # skip the top layer
            self.fpn_in.append(nn.Sequential(
                nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(fpn_dim),
                nn.ReLU(inplace=True)
            ))
        self.fpn_in = nn.ModuleList(self.fpn_in)

        self.fpn_out = []
        for i in range(len(fpn_inplanes) - 1):  # skip the top layer
            # The extra Sequential wrapper is kept for state_dict key compatibility.
            self.fpn_out.append(nn.Sequential(
                conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
            ))
        self.fpn_out = nn.ModuleList(self.fpn_out)

        self.conv_fusion = conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1)

        # background included. if ignore in loss, output channel 0 will not be trained.
        self.nr_scene_class, self.nr_object_class, self.nr_part_class, self.nr_material_class = \
            nr_classes['scene'], nr_classes['object'], nr_classes['part'], nr_classes['material']

        # input: PPM out, input_dim: fpn_dim
        self.scene_head = nn.Sequential(
            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(fpn_dim, self.nr_scene_class, kernel_size=1, bias=True)
        )

        # input: Fusion out, input_dim: fpn_dim
        self.object_head = nn.Sequential(
            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
            nn.Conv2d(fpn_dim, self.nr_object_class, kernel_size=1, bias=True)
        )

        # input: Fusion out, input_dim: fpn_dim
        self.part_head = nn.Sequential(
            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
            nn.Conv2d(fpn_dim, self.nr_part_class, kernel_size=1, bias=True)
        )

        # input: FPN_2 (P2), input_dim: fpn_dim
        self.material_head = nn.Sequential(
            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
            nn.Conv2d(fpn_dim, self.nr_material_class, kernel_size=1, bias=True)
        )

    def _split_part_logits(self, x, normalize):
        """Split stacked part channels into one tensor per object.

        There is one entry per object in ``self.object_with_part``. Each
        channel group is normalised on its own with ``normalize(..., dim=1)``.
        """
        part_pred_list, head = [], 0
        for object_label in self.object_with_part:
            n_part = len(self.object_part[object_label])
            part_pred_list.append(normalize(x[:, head: head + n_part], dim=1))
            head += n_part
        return part_pred_list

    def forward(self, conv_out, output_switch=None, seg_size=None):
        """Decode encoder feature maps into per-task predictions.

        Args:
            conv_out: list of encoder feature maps, shallow to deep; the last
                one feeds the PPM.
            output_switch: dict mapping 'object', 'part', 'scene' and
                'material' to bool. Heads switched off are returned as None.
                ``None`` enables every head.
            seg_size: output size for upsampling; used only when
                ``use_softmax`` is True.

        Returns:
            dict keyed like ``output_switch``.
            With ``use_softmax``:
                - scene: (B, C) probabilities.
                - object/material: probabilities upsampled to ``seg_size``.
                - part: list of per-object probability tensors at ``seg_size``.
            Otherwise the same entries are log-probabilities at feature
            resolution; scene is squeezed to (B, C) and part is a list.
        """
        if output_switch is None:
            output_switch = {"object": True, "part": True, "scene": True, "material": True}
        output_dict = {k: None for k in output_switch}

        conv5 = conv_out[-1]
        input_size = conv5.size()
        batch, height, width = input_size[0], input_size[2], input_size[3]
        # One whole-feature-map RoI per sample: (batch_idx, x0, y0, x1, y1),
        # built directly on conv5's device and dtype.
        roi = conv5.new_zeros((batch, 5))
        roi[:, 0] = torch.arange(batch, device=conv5.device)
        roi[:, 3] = width
        roi[:, 4] = height

        ppm_out = [conv5]
        for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
            pooled = pool_scale(conv5, roi)
            ppm_out.append(pool_conv(F.interpolate(
                pooled, (height, width), mode='bilinear', align_corners=False)))
        f = self.ppm_last_conv(torch.cat(ppm_out, 1))

        if output_switch['scene']:
            output_dict['scene'] = self.scene_head(f)

        if output_switch['object'] or output_switch['part'] or output_switch['material']:
            fpn_feature_list = [f]
            for i in reversed(range(len(conv_out) - 1)):
                conv_x = self.fpn_in[i](conv_out[i])  # lateral branch
                f = F.interpolate(
                    f, size=conv_x.size()[2:], mode='bilinear', align_corners=False)  # top-down branch
                f = conv_x + f
                fpn_feature_list.append(self.fpn_out[i](f))
            fpn_feature_list.reverse()  # [P2 - P5], finest level first

            if output_switch['material']:
                output_dict['material'] = self.material_head(fpn_feature_list[0])

            if output_switch['object'] or output_switch['part']:
                output_size = fpn_feature_list[0].size()[2:]
                fusion_list = [fpn_feature_list[0]] + [
                    F.interpolate(feat, output_size, mode='bilinear', align_corners=False)
                    for feat in fpn_feature_list[1:]
                ]
                x = self.conv_fusion(torch.cat(fusion_list, 1))
                if output_switch['object']:
                    output_dict['object'] = self.object_head(x)
                if output_switch['part']:
                    output_dict['part'] = self.part_head(x)

        if self.use_softmax:  # inference
            if output_dict['scene'] is not None:
                output_dict['scene'] = F.softmax(
                    output_dict['scene'].squeeze(3).squeeze(2), dim=1)
            for k in ('object', 'material'):
                if output_dict[k] is not None:
                    output_dict[k] = F.softmax(F.interpolate(
                        output_dict[k], size=seg_size, mode='bilinear', align_corners=False), dim=1)
            if output_dict['part'] is not None:
                # Upsample once; each object's slice is already at seg_size.
                x = F.interpolate(
                    output_dict['part'], size=seg_size, mode='bilinear', align_corners=False)
                output_dict['part'] = self._split_part_logits(x, F.softmax)
        else:  # training
            for k in ('object', 'scene', 'material'):
                if output_dict[k] is None:
                    continue
                x = F.log_softmax(output_dict[k], dim=1)
                if k == "scene":
                    x = x.squeeze(3).squeeze(2)
                output_dict[k] = x
            if output_dict['part'] is not None:
                output_dict['part'] = self._split_part_logits(output_dict['part'], F.log_softmax)
        return output_dict