Spaces:
Sleeping
Sleeping
# Copyright (C) 2021 * Ltd. All rights reserved. | |
# author : Sanghyeon Jo <josanghyeokn@gmail.com> | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import models | |
import torch.utils.model_zoo as model_zoo | |
from .arch_resnet import resnet | |
from .arch_resnest import resnest | |
from .abc_modules import ABC_Model | |
from .deeplab_utils import ASPP, Decoder | |
from .aff_utils import PathIndex | |
from .puzzle_utils import tile_features, merge_features | |
from tools.ai.torch_utils import resize_for_tensors | |
####################################################################### | |
# Normalization | |
####################################################################### | |
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
class FixedBatchNorm(nn.BatchNorm2d): | |
def forward(self, x): | |
return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps) | |
def group_norm(features): | |
return nn.GroupNorm(4, features) | |
####################################################################### | |
class Backbone(nn.Module, ABC_Model): | |
def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False): | |
super().__init__() | |
self.mode = mode | |
if self.mode == 'fix': | |
self.norm_fn = FixedBatchNorm | |
else: | |
self.norm_fn = nn.BatchNorm2d | |
if 'resnet' in model_name: | |
self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1), | |
batch_norm_fn=self.norm_fn) | |
state_dict = model_zoo.load_url(resnet.urls_dic[model_name]) | |
state_dict.pop('fc.weight') | |
state_dict.pop('fc.bias') | |
self.model.load_state_dict(state_dict) | |
else: | |
if segmentation: | |
dilation, dilated = 4, True | |
else: | |
dilation, dilated = 2, False | |
self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation, | |
norm_layer=self.norm_fn) | |
del self.model.avgpool | |
del self.model.fc | |
self.stage1 = nn.Sequential(self.model.conv1, | |
self.model.bn1, | |
self.model.relu, | |
self.model.maxpool) | |
self.stage2 = nn.Sequential(self.model.layer1) | |
self.stage3 = nn.Sequential(self.model.layer2) | |
self.stage4 = nn.Sequential(self.model.layer3) | |
self.stage5 = nn.Sequential(self.model.layer4) | |
class Classifier(Backbone): | |
def __init__(self, model_name, state_path, num_classes=20, mode='fix'): | |
super().__init__(model_name, state_path, num_classes, mode) | |
self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False) | |
self.num_classes = num_classes | |
self.initialize([self.classifier]) | |
def forward(self, x, with_cam=False): | |
x = self.stage1(x) | |
x = self.stage2(x) | |
x = self.stage3(x) | |
x = self.stage4(x) | |
x = self.stage5(x) | |
if with_cam: | |
features = self.classifier(x) | |
logits = self.global_average_pooling_2d(features) | |
return logits, features | |
else: | |
x = self.global_average_pooling_2d(x, keepdims=True) | |
logits = self.classifier(x).view(-1, self.num_classes) | |
return logits | |
class Classifier_For_Positive_Pooling(Backbone): | |
def __init__(self, model_name, num_classes=20, mode='fix'): | |
super().__init__(model_name, num_classes, mode) | |
self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False) | |
self.num_classes = num_classes | |
self.initialize([self.classifier]) | |
def forward(self, x, with_cam=False): | |
x = self.stage1(x) | |
x = self.stage2(x) | |
x = self.stage3(x) | |
x = self.stage4(x) | |
x = self.stage5(x) | |
if with_cam: | |
features = self.classifier(x) | |
logits = self.global_average_pooling_2d(features) | |
return logits, features | |
else: | |
x = self.global_average_pooling_2d(x, keepdims=True) | |
logits = self.classifier(x).view(-1, self.num_classes) | |
return logits | |
class Classifier_For_Puzzle(Classifier): | |
def __init__(self, model_name, num_classes=20, mode='fix'): | |
super().__init__(model_name, num_classes, mode) | |
def forward(self, x, num_pieces=1, level=-1): | |
batch_size = x.size()[0] | |
output_dic = {} | |
layers = [self.stage1, self.stage2, self.stage3, self.stage4, self.stage5, self.classifier] | |
for l, layer in enumerate(layers): | |
l += 1 | |
if level == l: | |
x = tile_features(x, num_pieces) | |
x = layer(x) | |
output_dic['stage%d'%l] = x | |
output_dic['logits'] = self.global_average_pooling_2d(output_dic['stage6']) | |
for l in range(len(layers)): | |
l += 1 | |
if l >= level: | |
output_dic['stage%d'%l] = merge_features(output_dic['stage%d'%l], num_pieces, batch_size) | |
if level is not None: | |
output_dic['merged_logits'] = self.global_average_pooling_2d(output_dic['stage6']) | |
return output_dic | |
class AffinityNet(Backbone): | |
def __init__(self, model_name, path_index=None): | |
super().__init__(model_name, None, 'fix') | |
if '50' in model_name: | |
fc_edge1_features = 64 | |
else: | |
fc_edge1_features = 128 | |
self.fc_edge1 = nn.Sequential( | |
nn.Conv2d(fc_edge1_features, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge2 = nn.Sequential( | |
nn.Conv2d(256, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge3 = nn.Sequential( | |
nn.Conv2d(512, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge4 = nn.Sequential( | |
nn.Conv2d(1024, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge5 = nn.Sequential( | |
nn.Conv2d(2048, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge6 = nn.Conv2d(160, 1, 1, bias=True) | |
self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4, self.stage5]) | |
self.edge_layers = nn.ModuleList([self.fc_edge1, self.fc_edge2, self.fc_edge3, self.fc_edge4, self.fc_edge5, self.fc_edge6]) | |
if path_index is not None: | |
self.path_index = path_index | |
self.n_path_lengths = len(self.path_index.path_indices) | |
for i, pi in enumerate(self.path_index.path_indices): | |
self.register_buffer("path_indices_" + str(i), torch.from_numpy(pi)) | |
def train(self, mode=True): | |
super().train(mode) | |
self.backbone.eval() | |
def forward(self, x, with_affinity=False): | |
x1 = self.stage1(x).detach() | |
x2 = self.stage2(x1).detach() | |
x3 = self.stage3(x2).detach() | |
x4 = self.stage4(x3).detach() | |
x5 = self.stage5(x4).detach() | |
edge1 = self.fc_edge1(x1) | |
edge2 = self.fc_edge2(x2) | |
edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)] | |
edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)] | |
edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)] | |
edge = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1)) | |
if with_affinity: | |
return edge, self.to_affinity(torch.sigmoid(edge)) | |
else: | |
return edge | |
def get_edge(self, x, image_size=512, stride=4): | |
feat_size = (x.size(2)-1)//stride+1, (x.size(3)-1)//stride+1 | |
x = F.pad(x, [0, image_size-x.size(3), 0, image_size-x.size(2)]) | |
edge_out = self.forward(x) | |
edge_out = edge_out[..., :feat_size[0], :feat_size[1]] | |
edge_out = torch.sigmoid(edge_out[0]/2 + edge_out[1].flip(-1)/2) | |
return edge_out | |
""" | |
aff = self.to_affinity(torch.sigmoid(edge_out)) | |
pos_aff_loss = (-1) * torch.log(aff + 1e-5) | |
neg_aff_loss = (-1) * torch.log(1. + 1e-5 - aff) | |
""" | |
def to_affinity(self, edge): | |
aff_list = [] | |
edge = edge.view(edge.size(0), -1) | |
for i in range(self.n_path_lengths): | |
ind = self._buffers["path_indices_" + str(i)] | |
ind_flat = ind.view(-1) | |
dist = torch.index_select(edge, dim=-1, index=ind_flat) | |
dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2)) | |
aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2) | |
aff_list.append(aff) | |
aff_cat = torch.cat(aff_list, dim=1) | |
return aff_cat | |
class DeepLabv3_Plus(Backbone): | |
def __init__(self, model_name, num_classes=21, mode='fix', use_group_norm=False): | |
super().__init__(model_name, num_classes, mode, segmentation=False) | |
if use_group_norm: | |
norm_fn_for_extra_modules = group_norm | |
else: | |
norm_fn_for_extra_modules = self.norm_fn | |
self.aspp = ASPP(output_stride=16, norm_fn=norm_fn_for_extra_modules) | |
self.decoder = Decoder(num_classes, 256, norm_fn_for_extra_modules) | |
def forward(self, x, with_cam=False): | |
inputs = x | |
x = self.stage1(x) | |
x = self.stage2(x) | |
x_low_level = x | |
x = self.stage3(x) | |
x = self.stage4(x) | |
x = self.stage5(x) | |
x = self.aspp(x) | |
x = self.decoder(x, x_low_level) | |
x = resize_for_tensors(x, inputs.size()[2:], align_corners=True) | |
return x | |
class Seg_Model(Backbone): | |
def __init__(self, model_name, num_classes=21): | |
super().__init__(model_name, num_classes, mode='fix', segmentation=False) | |
self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False) | |
def forward(self, inputs): | |
x = self.stage1(inputs) | |
x = self.stage2(x) | |
x = self.stage3(x) | |
x = self.stage4(x) | |
x = self.stage5(x) | |
logits = self.classifier(x) | |
# logits = resize_for_tensors(logits, inputs.size()[2:], align_corners=False) | |
return logits | |
class CSeg_Model(Backbone): | |
def __init__(self, model_name, num_classes=21): | |
super().__init__(model_name, num_classes, 'fix') | |
if '50' in model_name: | |
fc_edge1_features = 64 | |
else: | |
fc_edge1_features = 128 | |
self.fc_edge1 = nn.Sequential( | |
nn.Conv2d(fc_edge1_features, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge2 = nn.Sequential( | |
nn.Conv2d(256, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge3 = nn.Sequential( | |
nn.Conv2d(512, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge4 = nn.Sequential( | |
nn.Conv2d(1024, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge5 = nn.Sequential( | |
nn.Conv2d(2048, 32, 1, bias=False), | |
nn.GroupNorm(4, 32), | |
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False), | |
nn.ReLU(inplace=True), | |
) | |
self.fc_edge6 = nn.Conv2d(160, num_classes, 1, bias=True) | |
def forward(self, x): | |
x1 = self.stage1(x) | |
x2 = self.stage2(x1) | |
x3 = self.stage3(x2) | |
x4 = self.stage4(x3) | |
x5 = self.stage5(x4) | |
edge1 = self.fc_edge1(x1) | |
edge2 = self.fc_edge2(x2) | |
edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)] | |
edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)] | |
edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)] | |
logits = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1)) | |
# logits = resize_for_tensors(logits, x.size()[2:], align_corners=True) | |
return logits | |