from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from .hourglass import HourGlass
from utils.dct import DCT_Lowfrequency
from utils.filters_tensor import bgr2gray


class Quantize(Function):
    """Rounding op whose backward pass forwards the gradient unchanged."""

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` rounded element-wise to the nearest integer."""
        # Nothing is saved on ctx: backward() does not depend on the input,
        # so keeping x alive would only cost memory.
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` through as the gradient w.r.t. ``x``."""
        return grad_output


class ResHalf(nn.Module):
    """Encoder / quantizer / decoder pipeline built from two HourGlass nets.

    The encoder maps a 4-channel input (the image concatenated with one noise
    channel) to 1 channel; the decoder maps the quantized 1-channel result
    back to 3 channels.
    """

    def __init__(self, train=True, warm_stage=False):
        """Build the sub-networks.

        Args:
            train: stored as ``self.isTrain``; selects which tuple
                ``forward`` returns.
            warm_stage: when true, every decoder parameter gets
                ``requires_grad = False``.
        """
        super().__init__()
        self.encoder = HourGlass(inChannel=4, outChannel=1, resNum=4, convNum=4)
        self.decoder = HourGlass(inChannel=1, outChannel=3, resNum=4, convNum=4)
        self.dcter = DCT_Lowfrequency(size=256, fLimit=50)
        self.isTrain = train
        if warm_stage:
            for param in self.decoder.parameters():
                param.requires_grad = False

    def quantizer(self, x):
        """Quantize [-1, 1] data to {-1, 1}.

        The input is mapped to [0, 1], rounded via ``Quantize`` (so gradients
        pass straight through the rounding), and mapped back.
        """
        # A regular method rather than a lambda attribute, so the module
        # object itself remains picklable.
        return Quantize.apply(0.5 * (x + 1.)) * 2. - 1.

    def add_impluse_noise(self, input_halfs, p=0.0):
        """Return a copy of ``input_halfs`` with random pixels forced to +/-1.

        For each sample an independent (H, W) mask is drawn; each location is
        left untouched with probability ``1 - p``, set to 1.0 with
        probability ``p / 2`` and set to -1.0 with probability ``p / 2``.
        The same mask is applied to every channel of the sample.

        Args:
            input_halfs: tensor of shape (N, C, H, W).
            p: total probability that a location is overwritten.

        Returns:
            A new tensor on ``input_halfs.device`` with the same shape and
            dtype. It is built from a detached copy, so it carries no
            gradient history and ``input_halfs`` is never modified.
        """
        N, C, H, W = input_halfs.shape
        SNR = 1 - p
        # .numpy() shares memory with a CPU tensor; copy so that the caller's
        # tensor is not overwritten (on GPU .to("cpu") already copies, so
        # without this the side effect would differ between devices).
        noisy = input_halfs.detach().to("cpu").numpy().copy()
        for i in range(N):
            # Drawn with the original (H, W, 1) size so the random stream is
            # unchanged, then reduced to (H, W) so that it indexes the
            # spatial axes for any channel count.
            mask = np.random.choice(
                (0, 1, 2),
                size=(H, W, 1),
                p=[SNR, (1 - SNR) / 2., (1 - SNR) / 2.],
            )[..., 0]
            sample = noisy[i]  # (C, H, W) view into ``noisy``
            sample[:, mask == 1] = 1.0
            sample[:, mask == 2] = -1.0
        return torch.from_numpy(noisy).to(input_halfs.device)

    def forward(self, input_img, decoding_only=False):
        """Run the full pipeline, or only the decoder.

        Args:
            input_img: input tensor. In the full pipeline it is concatenated
                with one noise channel before entering the 4-channel encoder
                (so presumably 3 channels -- confirm against callers). With
                ``decoding_only`` it is fed to the quantizer and the
                1-channel decoder instead.
            decoding_only: skip the encoder and only quantize + decode.

        Returns:
            * ``restored`` when ``decoding_only`` is true;
            * ``(halfRes, halfDCT, refDCT, restored)`` when ``self.isTrain``;
            * ``(halfRes, restored)`` otherwise.
        """
        if decoding_only:
            halfResQ = self.quantizer(input_img)
            restored = self.decoder(halfResQ)
            return restored

        # Noise is drawn for every input channel but only the first channel
        # is used; kept as is so the random stream stays reproducible.
        noise = torch.randn_like(input_img) * 0.3
        halfRes = self.encoder(torch.cat((input_img, noise[:, :1, :, :]), dim=1))
        halfResQ = self.quantizer(halfRes)
        restored = self.decoder(halfResQ)
        if self.isTrain:
            # Both DCT inputs are shifted from [-1, 1] to [0, 1] first.
            # NOTE(review): bgr2gray presumably expects BGR channel order --
            # confirm against utils.filters_tensor.
            halfDCT = self.dcter(halfRes / 2. + 0.5)
            refDCT = self.dcter(bgr2gray(input_img / 2. + 0.5))
            return halfRes, halfDCT, refDCT, restored
        return halfRes, restored