"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""

import os
from functools import partial

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.io import loadmat
from torch.nn.modules import BatchNorm2d

from . import resnet
from . import mobilenet


NUM_CLASS = 150
base_path = os.path.dirname(os.path.abspath(__file__))  # current file path
colors_path = os.path.join(base_path, 'color150.mat')
classes_path = os.path.join(base_path, 'object150_info.csv')

# Loaded eagerly at import time from the files sitting next to this module.
segm_options = dict(colors=loadmat(colors_path)['colors'],
                    classes=pd.read_csv(classes_path),)


class NormalizeTensor:
    """Normalize a batched image tensor with per-channel mean and std.

    .. note::
        This transform acts out of place by default, i.e., it does not
        mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation inplace.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """Return ``(tensor - mean) / std``.

        The statistics are broadcast as ``[None, :, None, None]``, so the
        input must be 4-D with channels on dim 1.

        Args:
            tensor (Tensor): Tensor to be normalized.

        Returns:
            Tensor: Normalized tensor; a new tensor unless ``inplace`` is set.
        """
        if not self.inplace:
            tensor = tensor.clone()

        dtype = tensor.dtype
        mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device)
        tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
        return tensor


# Model Builder
class ModelBuilder:
    """Static factory helpers that build encoders and decoders."""

    # custom weights initialization
    @staticmethod
    def weights_init(m):
        """Initialise ``m`` in place; intended for ``nn.Module.apply``.

        Modules whose class name contains ``Conv`` get Kaiming-normal weights;
        modules whose class name contains ``BatchNorm`` get weight 1 and
        bias 1e-4. Every other module is left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.kaiming_normal_(m.weight.data)
        elif 'BatchNorm' in classname:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
        """Build an encoder network.

        Args:
            arch (str): One of ``mobilenetv2dilated``, ``resnet18``,
                ``resnet18dilated``, ``resnet50``, ``resnet50dilated``
                (case-insensitive).
            fc_dim (int): Accepted for signature symmetry with
                :meth:`build_decoder`; not used here.
            weights (str): Path to a checkpoint. When empty, the backbone
                factory is called with ``pretrained=True`` instead.

        Returns:
            nn.Module: The encoder.

        Raises:
            Exception: If ``arch`` is not one of the supported names.
        """
        pretrained = not weights
        arch = arch.lower()
        if arch == 'mobilenetv2dilated':
            orig_mobilenet = mobilenet.mobilenetv2(pretrained=pretrained)
            net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
        elif arch == 'resnet18':
            orig_resnet = resnet.resnet18(pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet18dilated':
            orig_resnet = resnet.resnet18(pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50dilated':
            orig_resnet = resnet.resnet50(pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50':
            orig_resnet = resnet.resnet50(pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        else:
            raise Exception('Architecture undefined!')

        # Encoders are usually pretrained, so no custom init is applied here.
        if weights:
            print('Loading weights for net_encoder')
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False)
        return net_encoder

    @staticmethod
    def build_decoder(arch='ppm_deepsup', fc_dim=512, num_class=NUM_CLASS,
                      weights='', use_softmax=False, drop_last_conv=False):
        """Build a decoder network.

        Args:
            arch (str): ``ppm_deepsup`` or ``c1_deepsup`` (case-insensitive).
            fc_dim (int): Channel count of the last encoder feature map.
            num_class (int): Number of output classes.
            weights (str): Path to a checkpoint; empty to keep the custom
                initialisation from :meth:`weights_init`.
            use_softmax (bool): Forwarded to the decoder.
            drop_last_conv (bool): Forwarded to the decoder.

        Returns:
            nn.Module: The decoder.

        Raises:
            Exception: If ``arch`` is not one of the supported names.
        """
        arch = arch.lower()
        if arch == 'ppm_deepsup':
            net_decoder = PPMDeepsup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        elif arch == 'c1_deepsup':
            net_decoder = C1DeepSup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        else:
            raise Exception('Architecture undefined!')

        net_decoder.apply(ModelBuilder.weights_init)
        if weights:
            print('Loading weights for net_decoder')
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False)
        return net_decoder

    @staticmethod
    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim,
                    drop_last_conv, *args, **kwargs):
        """Build a softmax decoder from the ADE20K checkpoint under ``weights_path``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        path = os.path.join(weights_path, 'ade20k',
                            f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
        return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim,
                                          weights=path, use_softmax=True,
                                          drop_last_conv=drop_last_conv)

    @staticmethod
    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim,
                    segmentation, *args, **kwargs):
        """Build an encoder.

        When ``segmentation`` is truthy the ADE20K checkpoint under
        ``weights_path`` is loaded; otherwise an empty weights path is passed
        on, which makes :meth:`build_encoder` request a pretrained backbone.
        Extra positional and keyword arguments are accepted and ignored.
        """
        if segmentation:
            path = os.path.join(weights_path, 'ade20k',
                                f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
        else:
            path = ''
        return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim,
                                          weights=path)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """Return a 3x3 conv (padding 1, no bias) followed by BatchNorm and ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                  padding=1, bias=False),
        BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )


def _nostride_dilate(m, dilate):
    """Turn stride-2 convs in ``m`` into stride-1 and dilate 3x3 convs.

    Only modules whose class name contains ``Conv`` are touched.
    """
    if 'Conv' not in m.__class__.__name__:
        return

    if m.stride == (2, 2):
        # the convolution with stride
        m.stride = (1, 1)
        if m.kernel_size == (3, 3):
            m.dilation = (dilate // 2, dilate // 2)
            m.padding = (dilate // 2, dilate // 2)
    elif m.kernel_size == (3, 3):
        # other convolutions
        m.dilation = (dilate, dilate)
        m.padding = (dilate, dilate)


def _adopt_resnet_layers(module, orig_resnet):
    """Attach every layer of ``orig_resnet`` except AvgPool and FC to ``module``."""
    module.conv1 = orig_resnet.conv1
    module.bn1 = orig_resnet.bn1
    module.relu1 = orig_resnet.relu1
    module.conv2 = orig_resnet.conv2
    module.bn2 = orig_resnet.bn2
    module.relu2 = orig_resnet.relu2
    module.conv3 = orig_resnet.conv3
    module.bn3 = orig_resnet.bn3
    module.relu3 = orig_resnet.relu3
    module.maxpool = orig_resnet.maxpool
    module.layer1 = orig_resnet.layer1
    module.layer2 = orig_resnet.layer2
    module.layer3 = orig_resnet.layer3
    module.layer4 = orig_resnet.layer4


def _resnet_forward(module, x, return_feature_maps):
    """Run the stem and four stages; return all stage outputs or only the last."""
    x = module.relu1(module.bn1(module.conv1(x)))
    x = module.relu2(module.bn2(module.conv2(x)))
    x = module.relu3(module.bn3(module.conv3(x)))
    x = module.maxpool(x)

    conv_out = []
    for stage in (module.layer1, module.layer2, module.layer3, module.layer4):
        x = stage(x)
        conv_out.append(x)

    if return_feature_maps:
        return conv_out
    return [x]


def _make_ppm_branches(fc_dim, pool_scales):
    """Build one pool -> 1x1 conv -> BN -> ReLU branch per pooling scale."""
    return nn.ModuleList([
        nn.Sequential(
            nn.AdaptiveAvgPool2d(scale),
            nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        for scale in pool_scales
    ])


def _ppm_concat(branches, conv5):
    """Concatenate ``conv5`` with every branch output resized to its H x W."""
    input_size = conv5.size()
    ppm_out = [conv5]
    for pool_scale in branches:
        ppm_out.append(nn.functional.interpolate(
            pool_scale(conv5),
            (input_size[2], input_size[3]),
            mode='bilinear', align_corners=False))
    return torch.cat(ppm_out, 1)


def _upsample_softmax(x, segSize):
    """Bilinearly resize ``x`` to ``segSize`` and apply softmax over dim 1."""
    x = nn.functional.interpolate(
        x, size=segSize, mode='bilinear', align_corners=False)
    return nn.functional.softmax(x, dim=1)


class SegmentationModule(nn.Module):
    """Encoder + decoder wrapper with a multi-scale ``predict`` entry point.

    Args:
        weights_path: Root directory holding the ``ade20k`` checkpoints.
        num_classes: Accepted for compatibility; not used by this class.
        arch_encoder: ``resnet50dilated`` or ``mobilenetv2dilated``.
        drop_last_conv: Forwarded to the default decoder.
        net_enc: Encoder to use; ``None`` builds the default one.
        net_dec: Decoder to use; ``None`` builds the default one.
        encode: Stored on the instance; {None, 'binary', 'color', 'sky'}.
        use_default_normalization: Normalise inputs inside ``predict``.
        return_feature_maps: Make ``forward`` also return the encoder
            feature maps and make ``predict`` return aggregated features.
        return_feature_maps_level: Which encoder feature map ``predict``
            aggregates; {0, 1, 2, 3}.
        return_feature_maps_only: Accepted for compatibility; not used.

    Raises:
        NotImplementedError: If ``arch_encoder`` is not supported.
    """

    def __init__(self,
                 weights_path,
                 num_classes=150,
                 arch_encoder="resnet50dilated",
                 drop_last_conv=False,
                 net_enc=None,  # None for Default encoder
                 net_dec=None,  # None for Default decoder
                 encode=None,  # {None, 'binary', 'color', 'sky'}
                 use_default_normalization=False,
                 return_feature_maps=False,
                 return_feature_maps_level=3,  # {0, 1, 2, 3}
                 return_feature_maps_only=True,
                 **kwargs,
                 ):
        super().__init__()
        self.weights_path = weights_path
        self.drop_last_conv = drop_last_conv
        self.arch_encoder = arch_encoder
        if self.arch_encoder == "resnet50dilated":
            self.arch_decoder = "ppm_deepsup"
            self.fc_dim = 2048
        elif self.arch_encoder == "mobilenetv2dilated":
            self.arch_decoder = "c1_deepsup"
            self.fc_dim = 320
        else:
            raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
        model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
                                    arch_decoder=self.arch_decoder,
                                    fc_dim=self.fc_dim,
                                    drop_last_conv=drop_last_conv,
                                    weights_path=self.weights_path)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if net_enc is None:
            # get_encoder requires `segmentation`; the default encoder is
            # paired with the ADE20K decoder checkpoint, so load the matching
            # ADE20K encoder checkpoint as well.
            self.encoder = ModelBuilder.get_encoder(segmentation=True,
                                                    **model_builder_kwargs)
        else:
            self.encoder = net_enc
        if net_dec is None:
            self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs)
        else:
            self.decoder = net_dec
        self.use_default_normalization = use_default_normalization
        self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])

        self.encode = encode

        self.return_feature_maps = return_feature_maps

        assert 0 <= return_feature_maps_level <= 3, \
            "return_feature_maps_level must be in {0, 1, 2, 3}"
        self.return_feature_maps_level = return_feature_maps_level

    def normalize_input(self, tensor):
        """Apply the default normalisation to a tensor with values in 0..1.

        Raises:
            ValueError: If any value lies outside ``[0, 1]``.
        """
        if tensor.min() < 0 or tensor.max() > 1:
            raise ValueError("Tensor should be 0..1 before using normalize_input")
        return self.default_normalization(tensor)

    @property
    def feature_maps_channels(self):
        """Channel count assumed for the selected feature-map level.

        NOTE(review): the 256/512/1024/2048 progression looks tied to the
        resnet50 encoder; confirm before using it with mobilenetv2dilated.
        """
        return 256 * 2**(self.return_feature_maps_level)  # 256, 512, 1024, 2048

    def forward(self, img_data, segSize=None):
        """Run encoder and decoder.

        Returns:
            The decoder output, or ``(output, feature_maps)`` when
            ``return_feature_maps`` is set.

        Raises:
            NotImplementedError: If ``segSize`` is not given.
        """
        if segSize is None:
            raise NotImplementedError("Please pass segSize param. By default: (300, 300)")

        fmaps = self.encoder(img_data, return_feature_maps=True)
        pred = self.decoder(fmaps, segSize=segSize)

        if self.return_feature_maps:
            return pred, fmaps
        return pred

    def multi_mask_from_multiclass(self, pred, classes):
        """Return a float mask that is 1 where ``pred`` equals any of ``classes``."""
        # Build the lookup tensor on pred's device so the comparison works
        # wherever the prediction lives.
        wanted = torch.as_tensor(classes, dtype=torch.long, device=pred.device)
        return (pred[..., None] == wanted).any(-1).float()

    @staticmethod
    def multi_mask_from_multiclass_probs(scores, classes):
        """Sum ``scores[:, c]`` over ``classes``; ``None`` if ``classes`` is empty.

        The accumulation is out of place so ``scores`` is never modified.
        """
        res = None
        for c in classes:
            res = scores[:, c] if res is None else res + scores[:, c]
        return res

    def predict(self, tensor, imgSizes=(-1,),  # (300, 375, 450, 525, 600)
                segSize=None):
        """Entry-point for segmentation. Use this method instead of forward

        Arguments:
            tensor {torch.Tensor} -- BCHW

        Keyword Arguments:
            imgSizes {tuple or list} -- imgSizes for segmentation input;
                -1 keeps the input resolution.
                default: (300, 450)
                original implementation: (300, 375, 450, 525, 600)
            segSize {tuple} -- output (H, W); defaults to the input size.

        Returns:
            ``(pred, result)``: the per-pixel argmax of the scale-averaged
            scores and the list of per-scale outputs. When
            ``return_feature_maps`` is set, only the scale-averaged feature
            map is returned.
        """
        if segSize is None:
            segSize = (tensor.shape[2], tensor.shape[3])

        with torch.no_grad():
            if self.use_default_normalization:
                tensor = self.normalize_input(tensor)

            # Allocate on the input's device so the module works wherever it
            # has been moved to.
            scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1],
                                 device=tensor.device)
            features = None
            if self.return_feature_maps:
                features = torch.zeros(1, self.feature_maps_channels,
                                       segSize[0], segSize[1],
                                       device=tensor.device)

            result = []
            for img_size in imgSizes:
                if img_size != -1:
                    img_data = F.interpolate(tensor.clone(), size=img_size)
                else:
                    img_data = tensor.clone()

                if self.return_feature_maps:
                    pred_current, fmaps = self.forward(img_data, segSize=segSize)
                else:
                    pred_current = self.forward(img_data, segSize=segSize)

                result.append(pred_current)
                scores = scores + pred_current / len(imgSizes)

                # Disclaimer: only the feature map at the configured level
                # is aggregated.
                if self.return_feature_maps:
                    features = features + F.interpolate(
                        fmaps[self.return_feature_maps_level],
                        size=segSize) / len(imgSizes)

            _, pred = torch.max(scores, dim=1)

            if self.return_feature_maps:
                return features

            return pred, result

    def get_edges(self, t):
        """Return a half-precision 0/1 map marking where ``t`` changes value.

        A position is marked when it differs from a neighbour along either
        of the last two dimensions of the 4-D input.
        """
        edge = torch.zeros(t.size(), dtype=torch.uint8, device=t.device)
        horizontal = t[:, :, :, 1:] != t[:, :, :, :-1]
        vertical = t[:, :, 1:, :] != t[:, :, :-1, :]
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | horizontal
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | horizontal
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | vertical
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | vertical
        return edge.half()


# pyramid pooling, deep supervision
class PPMDeepsup(nn.Module):
    """Pyramid-pooling decoder with an auxiliary deep-supervision head."""

    def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.ppm = _make_ppm_branches(fc_dim, pool_scales)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        """Decode the encoder feature maps in ``conv_out``.

        Returns:
            The concatenated pyramid features when ``drop_last_conv`` is set;
            softmax scores resized to ``segSize`` when ``use_softmax`` is set;
            otherwise a ``(main, deepsup)`` tuple of log-softmax maps.
        """
        ppm_out = _ppm_concat(self.ppm, conv_out[-1])

        if self.drop_last_conv:
            return ppm_out

        x = self.conv_last(ppm_out)

        if self.use_softmax:  # is True during inference
            return _upsample_softmax(x, segSize)

        # deep sup
        deepsup = self.cbr_deepsup(conv_out[-2])
        deepsup = self.dropout_deepsup(deepsup)
        deepsup = self.conv_last_deepsup(deepsup)

        x = nn.functional.log_softmax(x, dim=1)
        deepsup = nn.functional.log_softmax(deepsup, dim=1)

        return (x, deepsup)


class Resnet(nn.Module):
    """ResNet encoder that reuses the layers of ``orig_resnet``."""

    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # take pretrained resnet, except AvgPool and FC
        _adopt_resnet_layers(self, orig_resnet)

    def forward(self, x, return_feature_maps=False):
        """Return the four stage outputs, or ``[last]`` by default."""
        return _resnet_forward(self, x, return_feature_maps)


# Resnet Dilated
class ResnetDilated(nn.Module):
    """ResNet encoder whose late stages trade stride for dilation.

    ``dilate_scale=8`` dilates layer3 (by 2) and layer4 (by 4);
    ``dilate_scale=16`` dilates layer4 (by 2) only. The layers of
    ``orig_resnet`` are modified in place.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        _adopt_resnet_layers(self, orig_resnet)

    def _nostride_dilate(self, m, dilate):
        """Replace stride with dilation on conv module ``m``."""
        _nostride_dilate(m, dilate)

    def forward(self, x, return_feature_maps=False):
        """Return the four stage outputs, or ``[last]`` by default."""
        return _resnet_forward(self, x, return_feature_maps)


class MobileNetV2Dilated(nn.Module):
    """MobileNetV2 encoder whose late blocks trade stride for dilation."""

    def __init__(self, orig_net, dilate_scale=8):
        super(MobileNetV2Dilated, self).__init__()

        # take pretrained mobilenet features
        self.features = orig_net.features[:-1]

        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        if dilate_scale == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif dilate_scale == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )

    def _nostride_dilate(self, m, dilate):
        """Replace stride with dilation on conv module ``m``."""
        _nostride_dilate(m, dilate)

    def forward(self, x, return_feature_maps=False):
        """Return ``[final]``, or the ``down_idx`` outputs plus the final one."""
        if not return_feature_maps:
            return [self.features(x)]

        conv_out = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.down_idx:
                conv_out.append(x)
        conv_out.append(x)
        return conv_out


# last conv, deep supervision
class C1DeepSup(nn.Module):
    """Single conv-block decoder with an auxiliary deep-supervision head."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False,
                 drop_last_conv=False):
        super(C1DeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        """Decode the encoder feature maps in ``conv_out``.

        Returns:
            The conv-block features when ``drop_last_conv`` is set; softmax
            scores resized to ``segSize`` when ``use_softmax`` is set;
            otherwise a ``(main, deepsup)`` tuple of log-softmax maps.
        """
        x = self.cbr(conv_out[-1])

        if self.drop_last_conv:
            return x

        x = self.conv_last(x)

        if self.use_softmax:  # is True during inference
            return _upsample_softmax(x, segSize)

        # deep sup
        deepsup = self.cbr_deepsup(conv_out[-2])
        deepsup = self.conv_last_deepsup(deepsup)

        x = nn.functional.log_softmax(x, dim=1)
        deepsup = nn.functional.log_softmax(deepsup, dim=1)

        return (x, deepsup)


# last conv
class C1(nn.Module):
    """Single conv-block decoder."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super(C1, self).__init__()
        self.use_softmax = use_softmax

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        """Return resized softmax scores if ``use_softmax``, else log-softmax."""
        x = self.cbr(conv_out[-1])
        x = self.conv_last(x)

        if self.use_softmax:  # is True during inference
            return _upsample_softmax(x, segSize)
        return nn.functional.log_softmax(x, dim=1)


# pyramid pooling
class PPM(nn.Module):
    """Pyramid-pooling decoder."""

    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPM, self).__init__()
        self.use_softmax = use_softmax

        self.ppm = _make_ppm_branches(fc_dim, pool_scales)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )

    def forward(self, conv_out, segSize=None):
        """Return resized softmax scores if ``use_softmax``, else log-softmax."""
        x = self.conv_last(_ppm_concat(self.ppm, conv_out[-1]))

        if self.use_softmax:  # is True during inference
            return _upsample_softmax(x, segSize)
        return nn.functional.log_softmax(x, dim=1)