import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
import random
from .helper_arch import ResTextBlockV2, adaptive_instance_normalization


def _sn_conv3x3(in_ch, out_ch, stride=1):
    """Return a spectrally-normalised 3x3 convolution with padding 1."""
    return SpectralNorm(nn.Conv2d(in_ch, out_ch, 3, stride, 1))


def _conv_pair(in_ch, out_ch):
    """Return conv3x3 -> LeakyReLU(0.2) -> conv3x3, both spectrally normalised."""
    return nn.Sequential(
        _sn_conv3x3(in_ch, out_ch),
        nn.LeakyReLU(0.2),
        _sn_conv3x3(out_ch, out_ch),
    )


class SRNet(nn.Module):
    """Super-resolution network that modulates LQ features with per-character priors.

    The input is encoded at three resolutions (full, 1/2, 1/4), merged back
    to full resolution, optionally modulated by character priors, upsampled
    x2, optionally modulated again, and finally upsampled x2 once more into a
    3-channel ``tanh`` image (x4 overall).

    The ``32`` / ``16`` / ``8`` / ``64`` suffixes in attribute names
    presumably refer to feature-map heights for a 32-pixel-high input --
    TODO confirm against the data pipeline.

    Args:
        in_channel: Number of channels of the low-quality input image.
        dim_channel: Width of the main feature maps.
        prior32_channel: Channel count of the ``priors32`` tensors, which
            ``conv_32_to256`` projects to ``dim_channel``. ``priors64`` is
            not projected and must already have ``dim_channel`` channels.
    """

    def __init__(self, in_channel=3, dim_channel=256, prior32_channel=512):
        super().__init__()
        half = dim_channel // 2
        quarter = dim_channel // 4

        # Encoder: full resolution, then two stride-2 reductions.
        self.conv_first_32 = nn.Sequential(
            _sn_conv3x3(in_channel, quarter),
            nn.LeakyReLU(0.2),
        )
        self.conv_first_16 = nn.Sequential(
            _sn_conv3x3(quarter, half, stride=2),
            nn.LeakyReLU(0.2),
        )
        self.conv_first_8 = nn.Sequential(
            _sn_conv3x3(half, dim_channel, stride=2),
            nn.LeakyReLU(0.2),
            _sn_conv3x3(dim_channel, dim_channel),
        )

        # Decoder: each stage consumes the upsampled coarser features
        # concatenated with the encoder features of the same resolution.
        self.conv_body_16 = _conv_pair(dim_channel + half, dim_channel)
        self.conv_body_32 = _conv_pair(dim_channel + quarter, dim_channel)

        # First x2 upsampling stage.
        self.conv_up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            _sn_conv3x3(dim_channel, dim_channel),
            nn.LeakyReLU(0.2),
            ResTextBlockV2(dim_channel, dim_channel),
            _sn_conv3x3(dim_channel, dim_channel),
        )

        # Second x2 upsampling stage and projection to an image in [-1, 1].
        self.conv_final = nn.Sequential(
            _sn_conv3x3(dim_channel, half),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            _sn_conv3x3(half, quarter),
            nn.LeakyReLU(0.2),
            ResTextBlockV2(quarter, quarter),
            _sn_conv3x3(quarter, 3),
            nn.Tanh(),
        )

        # Prior fusion heads applied before ``conv_up``.
        self.conv_32_scale = _conv_pair(dim_channel, dim_channel)
        self.conv_32_shift = _conv_pair(dim_channel, dim_channel)
        self.conv_32_fuse = nn.Sequential(
            ResTextBlockV2(2 * dim_channel, dim_channel)
        )
        self.conv_32_to256 = _conv_pair(prior32_channel, dim_channel)

        # Prior fusion heads applied after ``conv_up``.
        self.conv_64_scale = _conv_pair(dim_channel, dim_channel)
        self.conv_64_shift = _conv_pair(dim_channel, dim_channel)
        self.conv_64_fuse = nn.Sequential(
            ResTextBlockV2(2 * dim_channel, dim_channel)
        )

    def _fuse_char_priors(self, feat, priors, locs, loc_divisor, half_width,
                          fuse, to_scale, to_shift, project=None):
        """Modulate ``feat`` with per-character priors and return the result.

        For every character ``c`` of every batch item ``b`` a horizontal
        window of at most ``2 * half_width`` columns around the character
        centre is cut out of ``feat``. The matching columns of the prior are
        AdaIN-normalised to that window, fused with it, and turned into a
        scale and a shift that modulate the window. The modulated windows are
        summed into a residual (overlapping windows accumulate) which is
        added to ``feat``.

        Args:
            feat: Feature map indexed as ``[batch, channel, row, column]``.
            priors: Iterable with one tensor per batch item, indexed as
                ``[character, channel, row, column]``, or ``None`` to skip
                the fusion entirely.
            locs: ``locs[b][c]`` is a scalar tensor holding the centre of
                character ``c`` in sample ``b``; it is divided by
                ``loc_divisor`` and truncated to get a column of ``feat``.
            loc_divisor: Ratio between ``locs`` coordinates and ``feat``
                columns.
            half_width: Half the window width in ``feat`` columns. The prior
                is assumed to be centred on column ``half_width`` of its own
                tile -- TODO confirm against the prior generator.
            fuse: Module applied to ``cat(adain_prior, window)``.
            to_scale: Module producing the multiplicative term.
            to_shift: Module producing the additive term.
            project: Optional module applied to each prior tensor first
                (used to bring the prior to the channel count of ``feat``).

        Returns:
            ``feat`` itself when ``priors`` is ``None``, otherwise
            ``feat + residual``.
        """
        if priors is None:
            return feat

        feat_width = feat.size(-1)
        residual = torch.zeros_like(feat)
        for b, char_priors in enumerate(priors):
            if project is not None:
                char_priors = project(char_priors)
            for c in range(char_priors.size(0)):
                # Locations carry no gradient; a plain int keeps the window
                # arithmetic and slicing in pure Python.
                center = (locs[b][c].detach() / loc_divisor).int().item()

                # Window in the feature map, clipped to its borders.
                x1 = max(center - half_width, 0)
                x2 = min(center + half_width, feat_width)
                if x2 <= x1:
                    # The window lies completely outside the feature map;
                    # there is nothing to modulate for this character.
                    continue

                # Matching columns in the prior: when the window is clipped
                # on the left, skip the same number of prior columns.
                y1 = max(half_width - center, 0)
                y2 = y1 + (x2 - x1)

                # Cloned because the helper's in-place behaviour is unknown.
                char_prior = char_priors[c:c + 1, :, :, y1:y2].clone()
                char_feat = feat[b:b + 1, :, :, x1:x2].clone()
                adain_prior = adaptive_instance_normalization(
                    char_prior, char_feat
                )
                fused = fuse(torch.cat((adain_prior, char_feat), dim=1))
                scale = to_scale(fused)
                shift = to_shift(fused)

                residual[b, :, :, x1:x2] = (
                    residual[b, :, :, x1:x2]
                    + feat[b, :, :, x1:x2] * scale[0]
                    + shift[0]
                )
        return feat + residual

    def forward(self, lq, priors64, priors32, locs):
        """Super-resolve ``lq`` using character priors at two scales.

        Args:
            lq: Low-quality input batch with ``in_channel`` channels.
            priors64: Per-sample character priors fused after the first
                upsampling (one tensor per batch item, ``dim_channel``
                channels), or ``None`` to skip that fusion.
            priors32: Per-sample character priors fused at input resolution
                (one tensor per batch item, ``prior32_channel`` channels),
                or ``None`` to skip that fusion.
            locs: Per-sample, per-character centre positions. They are
                divided by 4 for the first fusion and by 2 for the second,
                so they are presumably expressed at 4x the input width --
                TODO confirm with the caller. Only read when the
                corresponding priors are given.

        Returns:
            Tensor with 3 channels and 4x the spatial size of ``lq``, with
            values in ``[-1, 1]`` (``tanh`` output).
        """
        # Multi-scale encoding of the input.
        lq_f_32 = self.conv_first_32(lq)
        lq_f_16 = self.conv_first_16(lq_f_32)
        lq_f_8 = self.conv_first_8(lq_f_16)

        # Coarse-to-fine merge back to input resolution.
        up_8 = F.interpolate(lq_f_8, scale_factor=2, mode='bilinear')
        sq_f_16 = self.conv_body_16(torch.cat([up_8, lq_f_16], dim=1))
        up_16 = F.interpolate(sq_f_16, scale_factor=2, mode='bilinear')
        sq_f_32 = self.conv_body_32(torch.cat([up_16, lq_f_32], dim=1))

        # First prior fusion, at input resolution.
        sq_pf_32 = self._fuse_char_priors(
            sq_f_32, priors32, locs,
            loc_divisor=4.0, half_width=16,
            fuse=self.conv_32_fuse,
            to_scale=self.conv_32_scale,
            to_shift=self.conv_32_shift,
            project=self.conv_32_to256,
        )

        # Second prior fusion, after the first x2 upsampling.
        sq_f_64 = self.conv_up(sq_pf_32)
        sq_pf_64 = self._fuse_char_priors(
            sq_f_64, priors64, locs,
            loc_divisor=2.0, half_width=32,
            fuse=self.conv_64_fuse,
            to_scale=self.conv_64_scale,
            to_shift=self.conv_64_shift,
        )

        return self.conv_final(sq_pf_64)