import random  # kept for the commented-out location jitter in forward()

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm as SpectralNorm

from .helper_arch import ResTextBlockV2, adaptive_instance_normalization


class SRNet(nn.Module):
    """Text super-resolution network that fuses per-character structure
    priors into the LQ features at the 32 (1/4) and 64 (1/2) scales."""

    def __init__(self, in_channel=3, dim_channel=256):
        super().__init__()
        # Multi-scale LQ encoder: full, 1/2 and 1/4 resolution.
        self.conv_first_32 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel, dim_channel // 4, 3, 1, 1)),
            nn.LeakyReLU(0.2),
        )
        self.conv_first_16 = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel // 4, dim_channel // 2, 3, 2, 1)),
            nn.LeakyReLU(0.2),
        )
        self.conv_first_8 = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel // 2, dim_channel, 3, 2, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        # Decoder body: upsample and fuse with the encoder skip features.
        self.conv_body_16 = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel + dim_channel // 2, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        self.conv_body_32 = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel + dim_channel // 4, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        self.conv_up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),  # 64*64*256
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            ResTextBlockV2(dim_channel, dim_channel),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        self.conv_final = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel // 2, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2, mode='bilinear'),  # 128*128*256
            SpectralNorm(nn.Conv2d(dim_channel // 2, dim_channel // 4, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            ResTextBlockV2(dim_channel // 4, dim_channel // 4),
            SpectralNorm(nn.Conv2d(dim_channel // 4, 3, 3, 1, 1)),
            nn.Tanh(),
        )
        # Scale/shift modulation heads for the prior fusion at the 32 scale.
        self.conv_32_scale = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        self.conv_32_shift = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        self.conv_32_fuse = nn.Sequential(
            ResTextBlockV2(2 * dim_channel, dim_channel)
        )
        # Maps the incoming 512-channel structure-prior features to dim_channel.
        self.conv_32_to256 = nn.Sequential(
            SpectralNorm(nn.Conv2d(512, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        # Scale/shift modulation heads for the prior fusion at the 64 scale.
        self.conv_64_scale = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        self.conv_64_shift = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim_channel, dim_channel, 3, 1, 1)),
        )
        self.conv_64_fuse = nn.Sequential(
            ResTextBlockV2(2 * dim_channel, dim_channel)
        )

    def forward(self, lq, priors64, priors32, locs):
        # Encode the LQ input at three scales, then decode back with skips.
        lq_f_32 = self.conv_first_32(lq)
        lq_f_16 = self.conv_first_16(lq_f_32)
        lq_f_8 = self.conv_first_8(lq_f_16)
        sq_f_16 = self.conv_body_16(torch.cat([F.interpolate(lq_f_8, scale_factor=2, mode='bilinear'), lq_f_16], dim=1))
        sq_f_32 = self.conv_body_32(torch.cat([F.interpolate(sq_f_16, scale_factor=2, mode='bilinear'), lq_f_32], dim=1))

        # Fuse the per-character structure priors at the 32 (input) scale.
        if priors32 is not None:
            sq_f_32_ori = sq_f_32.clone()
            sq_f_32_res = sq_f_32.clone().detach() * 0
            for b, p_32 in enumerate(priors32):
                # Detach so no gradient flows back into the prior branch.
                p_32_256 = self.conv_32_to256(p_32.clone().detach())
                for c in range(p_32_256.size(0)):  # for each character
                    center = int(locs[b][c].item() / 4.0)  # + random.randint(-2, 2)  ### no backward
                    width = 16
                    if center < width:
                        x1 = 0  # lq feature left
                        y1 = max(16 - center, 0)
                    else:
                        x1 = center - width
                        y1 = max(16 - width, 0)  # y1 = 16 - width
                    if center + width > sq_f_32.size(-1):
                        x2 = sq_f_32.size(-1)  # lq feature right
                    else:
                        x2 = center + width
                    y2 = y1 + (x2 - x1)
                    # Center align: re-derive y1/y2 so the prior crop is
                    # centered; this supersedes the values computed above.
                    y1 = 16 - torch.div(x2 - x1, 2, rounding_mode='trunc')
                    y2 = y1 + x2 - x1
                    char_prior_f = p_32_256[c:c + 1, :, :, y1:y2].clone()  # prior crop
                    char_lq_f = sq_f_32[b:b + 1, :, :, x1:x2].clone()
                    # Match the prior's feature statistics to the LQ crop, then
                    # predict a scale/shift modulation from the fused features.
                    adain_prior_f = adaptive_instance_normalization(char_prior_f, char_lq_f)
                    fuse_32_prior = self.conv_32_fuse(torch.cat((adain_prior_f, char_lq_f), dim=1))
                    scale = self.conv_32_scale(fuse_32_prior)
                    shift = self.conv_32_shift(fuse_32_prior)
                    sq_f_32_res[b, :, :, x1:x2] = sq_f_32_res[b, :, :, x1:x2] + sq_f_32[b, :, :, x1:x2].clone() * scale[0, ...] + shift[0, ...]
            sq_pf_32_out = sq_f_32_ori + sq_f_32_res
        else:
            sq_pf_32_out = sq_f_32.clone()

        sq_f_64 = self.conv_up(sq_pf_32_out)  # 64*1024
        sq_f_64_ori = sq_f_64.clone()
        sq_f_64_res = sq_f_64.clone().detach() * 0
        # prior_full_64 = sq_f_64.clone().detach() * 0

        # Fuse the priors again at the 64 (2x) scale.
        if priors64 is not None:
            for b, p_64_prior in enumerate(priors64):
                p_64 = p_64_prior.clone().detach()  # no backward to prior
                for c in range(p_64.size(0)):  # for each character
                    center = int(locs[b][c].item() / 2.0)  # + random.randint(-4, 4)  ### no backward
                    width = 32
                    if center < width:
                        x1 = 0
                        y1 = max(32 - center, 0)
                    else:
                        x1 = center - width
                        y1 = max(32 - width, 0)
                    if center + width > sq_f_64.size(-1):
                        x2 = sq_f_64.size(-1)
                    else:
                        x2 = center + width
                    # Center align: re-derive y1/y2 so the prior crop is
                    # centered (supersedes the y1 computed above).
                    y1 = 32 - torch.div(x2 - x1, 2, rounding_mode='trunc')
                    y2 = y1 + x2 - x1
                    char_prior_f = p_64[c:c + 1, :, :, y1:y2].clone()
                    char_lq_f = sq_f_64[b:b + 1, :, :, x1:x2].clone()
                    adain_prior_f = adaptive_instance_normalization(char_prior_f, char_lq_f)
                    fuse_64_prior = self.conv_64_fuse(torch.cat((adain_prior_f, char_lq_f), dim=1))
                    scale = self.conv_64_scale(fuse_64_prior)
                    shift = self.conv_64_shift(fuse_64_prior)
                    sq_f_64_res[b, :, :, x1:x2] = sq_f_64_res[b, :, :, x1:x2] + sq_f_64[b, :, :, x1:x2].clone() * scale[0, ...] + shift[0, ...]
                    # prior_full_64[b, :, :, x1:x2] = prior_full_64[b, :, :, x1:x2] + char_prior_f.clone()
            sq_pf_64 = sq_f_64_ori + sq_f_64_res
        else:
            sq_pf_64 = sq_f_64_ori.clone()

        f256 = self.conv_final(sq_pf_64)
        # adain_lr2prior = adaptive_instance_normalization(prior_full_64, F.interpolate(sq_f_32_ori, scale_factor=2, mode='bilinear'))
        # prior_out = self.conv_priorout(adain_lr2prior)
        return f256  # prior_out
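

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an editorial addition, not part of the original
# network). The concrete sizes are assumptions inferred from the code, not
# values confirmed by this file alone: ResTextBlockV2 is assumed to preserve
# spatial size; per-character structure priors are assumed to arrive as
# 512-channel 32x32 crops (32 scale) and 256-channel 64x64 crops (64 scale);
# `locs` is assumed to hold character centers in the 4x output coordinate
# frame. Because of the relative import above, run this inside its package,
# e.g. `python -m <package>.<this_module>`.
if __name__ == '__main__':
    net = SRNet(in_channel=3, dim_channel=256)
    lq = torch.randn(1, 3, 32, 128)           # low-quality text-line image
    priors32 = [torch.randn(2, 512, 32, 32)]  # one tensor per image: (num_chars, C, H, W)
    priors64 = [torch.randn(2, 256, 64, 64)]
    locs = torch.tensor([[100.0, 300.0]])     # character centers at the 4x output scale
    with torch.no_grad():
        sr = net(lq, priors64, priors32, locs)
    # Expected (1, 3, 128, 512): 4x the input resolution, provided the
    # ResTextBlockV2 blocks keep spatial dimensions unchanged.
    print(sr.shape)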