import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

import models
from models import register
from utils import to_coordinates

from mmcv.cnn import ConvModule
from .blocks.CSPLayer import CSPLayer
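
# Assumptions about external helpers, inferred from how they are used below:
# - `models.make(spec, args=...)` instantiates a registered model from a spec
#   dict; the encoder is expected to expose its channel width as `out_dim`.
# - `to_coordinates(size, return_map=True)` is assumed to return an (H, W, 2)
#   map of cell-centre coordinates normalised to [-1, 1].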


@register('funsr')
class FUNSR(nn.Module):
    """Arbitrary-scale SR: an encoder feature map modulates a coordinate-based
    local decoder, optionally fused with a global decoder and refined over
    several forward passes."""

    def __init__(self,
                 encoder_spec,
                 has_multiscale=False,
                 neck=None,
                 decoder=None,
                 global_decoder=None,
                 encoder_rgb=False,
                 n_forward_times=1,
                 encode_hr_coord=False,
                 has_bn=True,
                 encode_scale_ratio=False,
                 local_unfold=False,
                 weight_gen_func='nearest-exact',
                 return_featmap=False,
                 ):
        super().__init__()
        self.weight_gen_func = weight_gen_func
        self.encoder = models.make(encoder_spec)
        self.encoder_out_dim = self.encoder.out_dim
        self.encode_scale_ratio = encode_scale_ratio
        self.has_multiscale = has_multiscale
        self.encoder_rgb = encoder_rgb
        self.encode_hr_coord = encode_hr_coord
        self.local_unfold = local_unfold
        self.return_featmap = return_featmap

        self.multiscale_layers = nn.ModuleList()

        if self.has_multiscale:
            # Three stride-2 conv + CSP stages form a small feature pyramid
            # on top of the encoder output.
            conv_cfg = None
            norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) if has_bn else None
            act_cfg = dict(type='ReLU')
            num_blocks = [2, 4, 6]
            for n_idx in range(3):
                conv_layer = ConvModule(
                    self.encoder_out_dim,
                    self.encoder_out_dim * 2,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
                csp_layer = CSPLayer(
                    self.encoder_out_dim * 2,
                    self.encoder_out_dim,
                    num_blocks=num_blocks[n_idx],
                    add_identity=True,
                    use_depthwise=False,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
                self.multiscale_layers.append(nn.Sequential(conv_layer, csp_layer))

        if neck is not None:
            self.neck = models.make(neck, args={'in_dim': self.encoder_out_dim})
            modulation_dim = self.neck.d_dim
        else:
            modulation_dim = self.encoder_out_dim

        self.n_forward_times = n_forward_times

        # Local decoder input: relative coordinate (2 channels), plus optional
        # scale ratio (2), HR coordinate (2) and RGB (3) channels.
        decoder_in_dim = 2
        if self.encode_scale_ratio:
            decoder_in_dim += 2
        if self.encode_hr_coord:
            decoder_in_dim += 2
        if self.encoder_rgb:
            decoder_in_dim += 3

        if decoder is not None:
            if self.local_unfold:
                # A 3x3 unfold multiplies channels by 9; project back down.
                self.down_dim_layer = nn.Conv2d(modulation_dim * 9, modulation_dim, 1)
            self.decoder = models.make(
                decoder, args={'modulation_dim': modulation_dim, 'in_dim': decoder_in_dim})

        if global_decoder is not None:
            # The global decoder sees absolute HR coordinates rather than
            # relative ones, so its input width is computed separately.
            decoder_in_dim = 2
            if self.encode_scale_ratio:
                decoder_in_dim += 2
            if self.encoder_rgb:
                decoder_in_dim += 3

            self.decoder_is_proj = global_decoder.get('is_proj', False)
            self.global_decoder = models.make(
                global_decoder, args={'modulation_dim': modulation_dim, 'in_dim': decoder_in_dim})

            if self.decoder_is_proj:
                self.input_proj = nn.Linear(modulation_dim, modulation_dim)
                # Fuse the concatenated 3+3 channel local/global outputs to RGB.
                self.output_proj = nn.Conv2d(6, 3, kernel_size=1)

    def forward_step(self,
                     lr_img,
                     func_map,
                     global_func,
                     rel_coord,
                     lr_coord,
                     hr_coord,
                     scale_ratio_map=None,
                     pred_rgb_value=None):
        if self.local_unfold:
            # Gather each cell's 3x3 neighbourhood, then project back to C channels.
            b, c, h, w = func_map.shape
            func_map = F.unfold(func_map, 3, padding=1).view(b, c * 9, h, w)
            func_map = self.down_dim_layer(func_map)
        # Upsample the modulation map to the HR grid.
        local_func_map = F.interpolate(func_map, size=hr_coord.shape[-2:], mode=self.weight_gen_func)

        rel_coord = repeat(rel_coord, 'b c h w -> (B b) c h w', B=lr_img.size(0))
        hr_coord = repeat(hr_coord, 'c h w -> B c h w', B=lr_img.size(0))
        local_input = rel_coord
        if self.encode_scale_ratio:
            local_input = torch.cat([local_input, scale_ratio_map], dim=1)
        if self.encode_hr_coord:
            local_input = torch.cat([local_input, hr_coord], dim=1)
        if self.encoder_rgb:
            if pred_rgb_value is None:
                # First pass: seed with a bicubic upsample of the LR input.
                pred_rgb_value = F.interpolate(lr_img, size=hr_coord.shape[-2:],
                                               mode='bicubic', align_corners=True)
            local_input = torch.cat((local_input, pred_rgb_value), dim=1)

        decoder_output = self.decoder(local_input, local_func_map)

        returned_featmap = None
        if hasattr(self, 'global_decoder'):
            if self.decoder_is_proj:
                global_func = self.input_proj(global_func)
            global_func = repeat(global_func, 'B C -> B C H W',
                                 H=hr_coord.shape[2], W=hr_coord.shape[3])

            global_input = hr_coord
            if self.encode_scale_ratio:
                global_input = torch.cat([global_input, scale_ratio_map], dim=1)
            if self.encoder_rgb:
                if pred_rgb_value is None:
                    pred_rgb_value = F.interpolate(lr_img, size=hr_coord.shape[-2:],
                                                   mode='bicubic', align_corners=True)
                global_input = torch.cat((global_input, pred_rgb_value), dim=1)
            global_decoder_output = self.global_decoder(global_input, global_func)

            # Fuse local and global predictions, either through the learned
            # 1x1 projection or by summation.
            if self.decoder_is_proj:
                fused = torch.cat((global_decoder_output, decoder_output), dim=1)
                if self.return_featmap:
                    returned_featmap = fused
                decoder_output = self.output_proj(fused)
            else:
                decoder_output = global_decoder_output + decoder_output

        return decoder_output, returned_featmap

    def forward_backbone(self, x, keep_ori_feat=True):
        # Encoder features, optionally followed by the multiscale pyramid.
        x = self.encoder(x)
        output_feats = []
        if keep_ori_feat:
            output_feats.append(x)
        for layer in self.multiscale_layers:
            x = layer(x)
            output_feats.append(x)
        return output_feats

    def get_coordinate_map(self, x, hr_size):
        B, C, H, W = x.shape
        x_coord = to_coordinates(x.shape[-2:], return_map=True).to(x.device).permute(2, 0, 1)
        hr_coord = to_coordinates(hr_size, return_map=True).to(x.device).permute(2, 0, 1)

        # Offset of each HR coordinate from its nearest LR cell centre,
        # rescaled so that one LR cell spans roughly [-1, 1].
        rel_grid = hr_coord - F.interpolate(x_coord.unsqueeze(0), size=hr_size, mode='nearest-exact')
        rel_grid[:, 0, :, :] *= H
        rel_grid[:, 1, :, :] *= W

        return rel_grid.contiguous().detach(), x_coord.contiguous().detach(), hr_coord.contiguous().detach()

    def forward(self, x, out_size):
        B, C, H_lr, W_lr = x.shape
        output_feats = self.forward_backbone(x)
        if hasattr(self, 'neck'):
            global_content, func_map = self.neck(output_feats)
        else:
            global_content = None
            func_map = output_feats[0]
        rel_coord, lr_coord, hr_coord = self.get_coordinate_map(x, out_size)

        scale_ratio_map = None
        if self.encode_scale_ratio:
            # Per-axis LR/HR size ratio, broadcast over the HR grid.
            h_ratio = x.shape[2] / out_size[0]
            w_ratio = x.shape[3] / out_size[1]
            scale_ratio_map = torch.tensor([h_ratio, w_ratio]).view(1, -1, 1, 1).expand(B, -1, *out_size).to(x.device)

        # Iterative refinement: each pass feeds its RGB prediction back in.
        pred_rgb_value = None
        return_pred_rgb_value = []
        for _ in range(self.n_forward_times):
            pred_rgb_value, returned_featmaps = self.forward_step(
                x,
                func_map,
                global_content,
                rel_coord,
                lr_coord,
                hr_coord,
                scale_ratio_map,
                pred_rgb_value)
            return_pred_rgb_value.append(pred_rgb_value)

        if self.return_featmap:
            return return_pred_rgb_value, returned_featmaps
        return return_pred_rgb_value
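

if __name__ == '__main__':
    # Minimal smoke test of the relative-coordinate logic in
    # `get_coordinate_map` (run with `python -m <package>.<module>` so the
    # relative import above resolves). The helper below is a local stand-in
    # for the assumed `to_coordinates` contract -- an (H, W, 2) map of
    # cell-centre coordinates in [-1, 1] -- and is illustrative only.
    def _coord_map(size):
        h, w = size
        ys = (torch.arange(h, dtype=torch.float32) + 0.5) / h * 2 - 1
        xs = (torch.arange(w, dtype=torch.float32) + 0.5) / w * 2 - 1
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
        return torch.stack([grid_y, grid_x], dim=-1)

    lr_size, hr_size = (8, 8), (16, 16)
    lr_coord = _coord_map(lr_size).permute(2, 0, 1)  # (2, 8, 8)
    hr_coord = _coord_map(hr_size).permute(2, 0, 1)  # (2, 16, 16)
    rel = hr_coord - F.interpolate(lr_coord.unsqueeze(0), size=hr_size, mode='nearest-exact')
    rel[:, 0] *= lr_size[0]
    rel[:, 1] *= lr_size[1]
    # Each HR sample's offset from its nearest LR cell centre, in LR-cell
    # units; for an exact 2x upscale the maximum magnitude is 0.5.
    print(rel.shape, rel.abs().max().item())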