"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" | |
import os | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from scipy.io import loadmat | |
from torch.nn.modules import BatchNorm2d | |
from . import resnet | |
from . import mobilenet | |
NUM_CLASS = 150 | |
base_path = os.path.dirname(os.path.abspath(__file__)) # current file path | |
colors_path = os.path.join(base_path, 'color150.mat') | |
classes_path = os.path.join(base_path, 'object150_info.csv') | |
segm_options = dict(colors=loadmat(colors_path)['colors'], | |
classes=pd.read_csv(classes_path),) | |
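

# Illustrative sketch (not part of the original module): how the palette and
# class table above might be used. Assumes 'colors' is a (150, 3) uint8 array
# and that object150_info.csv has a 'Name' column, as in the upstream repo.
def _example_lookup_class(idx=2):
    color = segm_options['colors'][idx]          # RGB triplet for class idx
    name = segm_options['classes']['Name'][idx]  # human-readable label
    return name, color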
class NormalizeTensor:
    def __init__(self, mean, std, inplace=False):
        """Normalize a batched tensor image with mean and standard deviation.

        .. note::
            This transform acts out of place by default, i.e., it does not mutate the input tensor.

        See :class:`~torchvision.transforms.Normalize` for more details.

        Args:
            mean (sequence): Sequence of means for each channel.
            std (sequence): Sequence of standard deviations for each channel.
            inplace (bool, optional): Bool to make this operation in-place.

        Calling the instance with a tensor image of size (B, C, H, W) returns
        the normalized tensor.
        """
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        if not self.inplace:
            tensor = tensor.clone()

        dtype = tensor.dtype
        mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device)
        tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
        return tensor
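

# Illustrative usage sketch (not in the original file): NormalizeTensor expects
# a batched BCHW tensor, unlike torchvision's Normalize, which works on CHW.
def _example_normalize():
    normalize = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
    batch = torch.rand(1, 3, 64, 64)  # values in [0, 1]
    return normalize(batch)           # same shape, per-channel standardized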
# Model Builder
class ModelBuilder:
    # custom weights initialization
    @staticmethod
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
        pretrained = len(weights) == 0
        arch = arch.lower()
        if arch == 'mobilenetv2dilated':
            orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
            net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
        elif arch == 'resnet18':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet18dilated':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50dilated':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        else:
            raise Exception('Architecture undefined!')

        # encoders are usually pretrained
        # net_encoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_encoder')
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_encoder

    @staticmethod
    def build_decoder(arch='ppm_deepsup',
                      fc_dim=512, num_class=NUM_CLASS,
                      weights='', use_softmax=False, drop_last_conv=False):
        arch = arch.lower()
        if arch == 'ppm_deepsup':
            net_decoder = PPMDeepsup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        elif arch == 'c1_deepsup':
            net_decoder = C1DeepSup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        else:
            raise Exception('Architecture undefined!')

        net_decoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_decoder')
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_decoder

    @staticmethod
    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *args, **kwargs):
        path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
        return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path,
                                          use_softmax=True, drop_last_conv=drop_last_conv)

    @staticmethod
    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation=True,
                    *args, **kwargs):
        # segmentation defaults to True so calls that omit it (e.g. with the
        # kwargs built in SegmentationModule.__init__ below) still work
        if segmentation:
            path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
        else:
            path = ''
        return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path)
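

# Illustrative sketch (not in the original file): building the encoder/decoder
# pair by hand. Assumes the ADE20K checkpoints have been downloaded under
# 'weights_path' in the directory layout that get_encoder/get_decoder expect.
def _example_build_models(weights_path='./weights'):
    kwargs = dict(arch_encoder='resnet50dilated', arch_decoder='ppm_deepsup',
                  fc_dim=2048, drop_last_conv=False, segmentation=True)
    encoder = ModelBuilder.get_encoder(weights_path, **kwargs)
    decoder = ModelBuilder.get_decoder(weights_path, **kwargs)
    return encoder, decoder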
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
class SegmentationModule(nn.Module):
    def __init__(self,
                 weights_path,
                 num_classes=150,
                 arch_encoder="resnet50dilated",
                 drop_last_conv=False,
                 net_enc=None,  # None for default encoder
                 net_dec=None,  # None for default decoder
                 encode=None,  # {None, 'binary', 'color', 'sky'}
                 use_default_normalization=False,
                 return_feature_maps=False,
                 return_feature_maps_level=3,  # {0, 1, 2, 3}
                 return_feature_maps_only=True,
                 **kwargs,
                 ):
        super().__init__()
        self.weights_path = weights_path
        self.drop_last_conv = drop_last_conv
        self.arch_encoder = arch_encoder
        if self.arch_encoder == "resnet50dilated":
            self.arch_decoder = "ppm_deepsup"
            self.fc_dim = 2048
        elif self.arch_encoder == "mobilenetv2dilated":
            self.arch_decoder = "c1_deepsup"
            self.fc_dim = 320
        else:
            raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
        model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
                                    arch_decoder=self.arch_decoder,
                                    fc_dim=self.fc_dim,
                                    drop_last_conv=drop_last_conv,
                                    weights_path=self.weights_path)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc
        self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
        self.use_default_normalization = use_default_normalization
        self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])
        self.encode = encode

        self.return_feature_maps = return_feature_maps

        assert 0 <= return_feature_maps_level <= 3
        self.return_feature_maps_level = return_feature_maps_level

    def normalize_input(self, tensor):
        if tensor.min() < 0 or tensor.max() > 1:
            raise ValueError("Tensor should be 0..1 before using normalize_input")
        return self.default_normalization(tensor)

    def feature_maps_channels(self):
        return 256 * 2**self.return_feature_maps_level  # 256, 512, 1024, 2048

    def forward(self, img_data, segSize=None):
        if segSize is None:
            raise NotImplementedError("Please pass segSize param. By default: (300, 300)")

        fmaps = self.encoder(img_data, return_feature_maps=True)
        pred = self.decoder(fmaps, segSize=segSize)

        if self.return_feature_maps:
            return pred, fmaps
        # print("BINARY", img_data.shape, pred.shape)
        return pred

    def multi_mask_from_multiclass(self, pred, classes):
        def isin(ar1, ar2):
            return (ar1[..., None] == ar2).any(-1).float()
        return isin(pred, torch.LongTensor(classes).to(self.device))

    @staticmethod
    def multi_mask_from_multiclass_probs(scores, classes):
        res = None
        for c in classes:
            if res is None:
                res = scores[:, c]
            else:
                res += scores[:, c]
        return res

    def predict(self, tensor, imgSizes=(-1,),  # (300, 375, 450, 525, 600)
                segSize=None):
        """Entry-point for segmentation. Use this method instead of forward.

        Arguments:
            tensor {torch.Tensor} -- BCHW
        Keyword Arguments:
            imgSizes {tuple or list} -- imgSizes for segmentation input.
                default: (-1,), i.e. the input is used at its original size
                original implementation: (300, 375, 450, 525, 600)
        """
        if segSize is None:
            segSize = (tensor.shape[2], tensor.shape[3])

        with torch.no_grad():
            if self.use_default_normalization:
                tensor = self.normalize_input(tensor)
            scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device)
            features = torch.zeros(1, self.feature_maps_channels(), segSize[0], segSize[1]).to(self.device)

            result = []
            for img_size in imgSizes:
                if img_size != -1:
                    img_data = F.interpolate(tensor.clone(), size=img_size)
                else:
                    img_data = tensor.clone()

                if self.return_feature_maps:
                    pred_current, fmaps = self.forward(img_data, segSize=segSize)
                else:
                    pred_current = self.forward(img_data, segSize=segSize)

                result.append(pred_current)
                scores = scores + pred_current / len(imgSizes)

                # Disclaimer: we use and aggregate only the feature maps at
                # self.return_feature_maps_level (by default the last ones, fmaps[3])
                if self.return_feature_maps:
                    features = features + F.interpolate(fmaps[self.return_feature_maps_level],
                                                        size=segSize) / len(imgSizes)

            _, pred = torch.max(scores, dim=1)

            if self.return_feature_maps:
                return features

        return pred, result

    def get_edges(self, t):
        edge = torch.zeros(t.size(), dtype=torch.uint8, device=t.device)
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        return edge.half()  # the float() branch below this was unreachable
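

# Illustrative usage sketch (not in the original file). Assumes the ADE20K
# checkpoints are available under './weights' in the expected layout, and that
# image_bchw is a single image batch (1, 3, H, W) with values in [0, 1].
def _example_predict(image_bchw):
    module = SegmentationModule(weights_path='./weights',
                                use_default_normalization=True)
    module.to(module.device).eval()
    pred, _ = module.predict(image_bchw.to(module.device))
    return pred  # (1, H, W) tensor of ADE20K class indices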
# pyramid pooling, deep supervision
class PPMDeepsup(nn.Module):
    def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        if self.drop_last_conv:
            return ppm_out
        else:
            x = self.conv_last(ppm_out)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.dropout_deepsup(_)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)

            return (x, _)
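

# Illustrative shape sketch (not in the original file): feeding random feature
# maps through PPMDeepsup in inference mode. Channel counts match the
# resnet50dilated encoder (256/512/1024/2048 at levels 0..3); no weights needed.
def _example_ppm_deepsup():
    decoder = PPMDeepsup(fc_dim=2048, use_softmax=True).eval()
    conv_out = [torch.rand(1, c, 32, 32) for c in (256, 512, 1024, 2048)]
    with torch.no_grad():
        probs = decoder(conv_out, segSize=(256, 256))
    return probs  # (1, 150, 256, 256), softmax over classes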
class Resnet(nn.Module):
    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]
# Resnet Dilated
class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]
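

# Illustrative sketch (not in the original file): what _nostride_dilate does to
# a strided 3x3 convolution. The stride is removed and replaced with dilation,
# so the layer keeps its receptive field without downsampling.
def _example_nostride_dilate():
    from functools import partial
    conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
    # None stands in for `self`, which _nostride_dilate never uses
    conv.apply(partial(ResnetDilated._nostride_dilate, None, dilate=4))
    return conv.stride, conv.dilation, conv.padding  # (1, 1), (2, 2), (2, 2)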
class MobileNetV2Dilated(nn.Module):
    def __init__(self, orig_net, dilate_scale=8):
        super(MobileNetV2Dilated, self).__init__()
        from functools import partial

        # take pretrained mobilenet features
        self.features = orig_net.features[:-1]

        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        if dilate_scale == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif dilate_scale == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        if return_feature_maps:
            conv_out = []
            for i in range(self.total_idx):
                x = self.features[i](x)
                if i in self.down_idx:
                    conv_out.append(x)
            conv_out.append(x)
            return conv_out
        else:
            return [self.features(x)]
# last conv, deep supervision
class C1DeepSup(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
        super(C1DeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        x = self.cbr(conv5)

        if self.drop_last_conv:
            return x
        else:
            x = self.conv_last(x)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)

            return (x, _)
# last conv
class C1(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super(C1, self).__init__()
        self.use_softmax = use_softmax

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        x = self.cbr(conv5)
        x = self.conv_last(x)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)

        return x
# pyramid pooling
class PPM(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPM, self).__init__()
        self.use_softmax = use_softmax

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        x = self.conv_last(ppm_out)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)

        return x
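

# Minimal smoke test (not in the original file). Runs the plain PPM decoder on
# random feature maps; no pretrained weights are needed for this check.
if __name__ == '__main__':
    decoder = PPM(fc_dim=2048, use_softmax=True).eval()
    fake_conv_out = [torch.rand(1, 2048, 32, 32)]
    with torch.no_grad():
        out = decoder(fake_conv_out, segSize=(128, 128))
    print(out.shape)  # torch.Size([1, 150, 128, 128])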