import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from pathlib import Path

from .miniViT import mViT
from modules.shared import opts


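# UpSampleBN fuses a low-resolution decoder feature map with a higher-resolution
# encoder skip connection: bilinear upsampling to the skip's spatial size,
# channel-wise concatenation, then two Conv-BN-LeakyReLU refinement blocks.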
class UpSampleBN(nn.Module):
    def __init__(self, skip_input, output_features):
        super(UpSampleBN, self).__init__()

        self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
                                  nn.BatchNorm2d(output_features),
                                  nn.LeakyReLU(),
                                  nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
                                  nn.BatchNorm2d(output_features),
                                  nn.LeakyReLU())

    def forward(self, x, concat_with):
        up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)


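# DecoderBN progressively upsamples the encoder bottleneck: a 1x1 convolution
# maps it to `num_features` channels, four UpSampleBN stages halve the channel
# count while fusing encoder skip connections, and a final 3x3 convolution
# produces `num_classes` output channels.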
class DecoderBN(nn.Module):
    def __init__(self, num_features=2048, num_classes=1, bottleneck_features=2048):
        super(DecoderBN, self).__init__()
        features = int(num_features)

        # Note: padding=1 with a 1x1 kernel matches the upstream AdaBins implementation.
        self.conv2 = nn.Conv2d(bottleneck_features, features, kernel_size=1, stride=1, padding=1)

        # The summed constants equal the channel counts of the EfficientNet-B5
        # activations concatenated as skip connections at each stage.
        self.up1 = UpSampleBN(skip_input=features // 1 + 112 + 64, output_features=features // 2)
        self.up2 = UpSampleBN(skip_input=features // 2 + 40 + 24, output_features=features // 4)
        self.up3 = UpSampleBN(skip_input=features // 4 + 24 + 16, output_features=features // 8)
        self.up4 = UpSampleBN(skip_input=features // 8 + 16 + 8, output_features=features // 16)

        self.conv3 = nn.Conv2d(features // 16, num_classes, kernel_size=3, stride=1, padding=1)

    def forward(self, features):
        # Pick skip connections by index from the list of activations returned by Encoder.
        x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]

        x_d0 = self.conv2(x_block4)

        x_d1 = self.up1(x_d0, x_block3)
        x_d2 = self.up2(x_d1, x_block2)
        x_d3 = self.up3(x_d2, x_block1)
        x_d4 = self.up4(x_d3, x_block0)

        out = self.conv3(x_d4)
        return out


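# Encoder wraps the EfficientNet backbone and records every intermediate
# activation (including each sub-block inside `blocks`) so the decoder can
# select skip connections by index.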
class Encoder(nn.Module):
    def __init__(self, backend):
        super(Encoder, self).__init__()
        self.original_model = backend

    def forward(self, x):
        features = [x]
        for k, v in self.original_model._modules.items():
            if k == 'blocks':
                for ki, vi in v._modules.items():
                    features.append(vi(features[-1]))
            else:
                features.append(v(features[-1]))
        return features


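# UnetAdaptiveBins is the full AdaBins depth estimator: the encoder-decoder
# yields a 128-channel feature map, the mViT head predicts normalized adaptive
# bin widths plus range-attention maps, and conv_out turns those maps into a
# per-pixel softmax over the depth bins.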
class UnetAdaptiveBins(nn.Module):
    def __init__(self, backend, n_bins=100, min_val=0.1, max_val=10, norm='linear'):
        super(UnetAdaptiveBins, self).__init__()
        self.num_classes = n_bins
        self.min_val = min_val
        self.max_val = max_val
        self.encoder = Encoder(backend)
        self.adaptive_bins_layer = mViT(128, n_query_channels=128, patch_size=16,
                                        dim_out=n_bins,
                                        embedding_dim=128, norm=norm)

        self.decoder = DecoderBN(num_classes=128)
        self.conv_out = nn.Sequential(nn.Conv2d(128, n_bins, kernel_size=1, stride=1, padding=0),
                                      nn.Softmax(dim=1))

    def forward(self, x, **kwargs):
        unet_out = self.decoder(self.encoder(x), **kwargs)
        bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(unet_out)
        out = self.conv_out(range_attention_maps)

        # Scale the normalized widths to the [min_val, max_val] depth range,
        # prepend min_val, and cumulatively sum to obtain the n_bins + 1 bin edges.
        bin_widths = (self.max_val - self.min_val) * bin_widths_normed
        bin_widths = nn.functional.pad(bin_widths, (1, 0), mode='constant', value=self.min_val)
        bin_edges = torch.cumsum(bin_widths, dim=1)

        centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
        n, dout = centers.size()
        centers = centers.view(n, dout, 1, 1)

        # The predicted depth is the per-pixel expectation of the bin centers
        # under the softmax distribution produced by conv_out.
        pred = torch.sum(out * centers, dim=1, keepdim=True)

        return bin_edges, pred

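    # Parameter groups for discriminative learning rates: the pretrained
    # encoder is trained at the base rate, the rest of the network at 10x.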
    def get_1x_lr_params(self):
        return self.encoder.parameters()

    def get_10x_lr_params(self):
        modules = [self.decoder, self.adaptive_bins_layer, self.conv_out]
        for m in modules:
            yield from m.parameters()

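    # build() loads the 'tf_efficientnet_b5_ap' backbone via torch.hub,
    # preferring a previously downloaded local copy of the
    # rwightman/gen-efficientnet-pytorch repo, strips its classification head,
    # and wraps it in the encoder-decoder model.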
    @classmethod
    def build(cls, n_bins, **kwargs):
        DEBUG_MODE = opts.data.get("deforum_debug_mode_enabled", False)
        basemodel_name = 'tf_efficientnet_b5_ap'

        print('Loading AdaBins model...')
        # Use pathlib so the cache lookup works on both Windows and POSIX systems.
        predicted_torch_model_cache_path = Path.home() / '.cache' / 'torch' / 'hub' / 'rwightman_gen-efficientnet-pytorch_master'
        predicted_gep_cache_testfile = predicted_torch_model_cache_path / 'hubconf.py'

        if os.path.isfile(predicted_gep_cache_testfile):
            basemodel = torch.hub.load(str(predicted_torch_model_cache_path), basemodel_name, pretrained=True, source='local')
        else:
            basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
        if DEBUG_MODE:
            print('Done.')

        if DEBUG_MODE:
            print('Removing last two layers (global_pool & classifier).')
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        if DEBUG_MODE:
            print('Building Encoder-Decoder model...', end='')
        m = cls(basemodel, n_bins=n_bins, **kwargs)
        if DEBUG_MODE:
            print('Done.')
        return m


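# Quick smoke test when the module is run directly.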
if __name__ == '__main__':
    model = UnetAdaptiveBins.build(100)
    x = torch.rand(2, 3, 480, 640)
    bins, pred = model(x)
    print(bins.shape, pred.shape)  # bin edges: (2, n_bins + 1); depth: (2, 1, H/2, W/2)