menghanxia's picture
created the space
6e70c4a
raw
history blame
2.39 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from .hourglass import HourGlass
from utils.dct import DCT_Lowfrequency
from utils.filters_tensor import bgr2gray
from collections import OrderedDict
import numpy as np
class Quantize(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
y = x.round()
return y
@staticmethod
def backward(ctx, grad_output):
inputX = ctx.saved_tensors
return grad_output
class ResHalf(nn.Module):
def __init__(self, train=True, warm_stage=False):
super(ResHalf, self).__init__()
self.encoder = HourGlass(inChannel=4, outChannel=1, resNum=4, convNum=4)
self.decoder = HourGlass(inChannel=1, outChannel=3, resNum=4, convNum=4)
self.dcter = DCT_Lowfrequency(size=256, fLimit=50)
# quantize [-1,1] data to be {-1,1}
self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
self.isTrain = train
if warm_stage:
for name, param in self.decoder.named_parameters():
param.requires_grad = False
def add_impluse_noise(self, input_halfs, p=0.0):
N,C,H,W = input_halfs.shape
SNR = 1-p
np_input_halfs = input_halfs.detach().to("cpu").numpy()
np_input_halfs = np.transpose(np_input_halfs, (0, 2, 3, 1))
for i in range(N):
mask = np.random.choice((0, 1, 2), size=(H, W, 1), p=[SNR, (1 - SNR) / 2., (1 - SNR) / 2.])
np_input_halfs[i, mask==1] = 1.0
np_input_halfs[i, mask==2] = -1.0
return torch.from_numpy(np_input_halfs.transpose((0, 3, 1, 2))).to(input_halfs.device)
def forward(self, input_img, decoding_only=False):
if decoding_only:
halfResQ = self.quantizer(input_img)
restored = self.decoder(halfResQ)
return restored
noise = torch.randn_like(input_img) * 0.3
halfRes = self.encoder(torch.cat((input_img, noise[:,:1,:,:]), dim=1))
halfResQ = self.quantizer(halfRes)
restored = self.decoder(halfResQ)
if self.isTrain:
halfDCT = self.dcter(halfRes / 2. + 0.5)
refDCT = self.dcter(bgr2gray(input_img / 2. + 0.5))
return halfRes, halfDCT, refDCT, restored
else:
return halfRes, restored