Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Function | |
from .hourglass import HourGlass | |
from utils.dct import DCT_Lowfrequency | |
from utils.filters_tensor import bgr2gray | |
from collections import OrderedDict | |
import numpy as np | |
class Quantize(Function):
    """Element-wise rounding with a straight-through gradient.

    The forward pass rounds every element with ``Tensor.round``; the backward
    pass hands the incoming gradient back unchanged, i.e. rounding is treated
    as the identity function for differentiation purposes.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` rounded element-wise."""
        # Nothing is stored on ``ctx``: backward never looks at the input, so
        # saving it would only keep the tensor alive longer than needed.
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. the input, equal to ``grad_output``."""
        return grad_output
class ResHalf(nn.Module):
    """Encoder -> quantizer -> decoder network built from two ``HourGlass`` nets.

    The encoder maps a 4-channel input (the image concatenated with one noise
    channel) to a 1-channel map, the quantizer binarises that map, and the
    decoder maps the 1-channel binary map back to 3 channels.

    Args:
        train: stored as ``self.isTrain``; selects which tuple ``forward``
            returns (see ``forward``).
        warm_stage: when true, every decoder parameter gets
            ``requires_grad = False`` so only the encoder is optimised.
    """

    def __init__(self, train=True, warm_stage=False):
        super().__init__()
        self.encoder = HourGlass(inChannel=4, outChannel=1, resNum=4, convNum=4)
        self.decoder = HourGlass(inChannel=1, outChannel=3, resNum=4, convNum=4)
        # NOTE(review): presumably a low-frequency DCT of 256x256 inputs
        # limited to 50 coefficients -- confirm against utils.dct.
        self.dcter = DCT_Lowfrequency(size=256, fLimit=50)
        self.isTrain = train
        if warm_stage:
            for param in self.decoder.parameters():
                param.requires_grad = False

    def quantizer(self, x):
        """Quantize [-1, 1] data to {-1, 1}.

        The input is shifted to [0, 1], rounded by ``Quantize`` and shifted
        back. This is a method rather than a lambda attribute so that the
        module instance remains picklable.
        """
        return Quantize.apply(0.5 * (x + 1.)) * 2. - 1.

    def add_impluse_noise(self, input_halfs, p=0.0):
        """Return a detached copy of ``input_halfs`` with impulse noise added.

        For each sample, one label per spatial position is drawn with
        ``np.random.choice``: with probability ``1 - p`` the position is kept,
        and with probability ``p / 2`` each it is set to ``1.0`` or ``-1.0``.
        The same label is applied to every channel of that position.

        Args:
            input_halfs: tensor of shape ``(N, C, H, W)``.
            p: total probability of corrupting a position.

        Returns:
            A new tensor on the same device as ``input_halfs``. The input
            itself is never modified and no gradient flows through the result.
        """
        N, C, H, W = input_halfs.shape
        SNR = 1 - p
        probs = [SNR, (1 - SNR) / 2., (1 - SNR) / 2.]
        # One draw per sample, in order, so the NumPy random stream is consumed
        # exactly as a per-sample loop over the batch would consume it.
        labels = np.zeros((N, H, W, 1), dtype=np.int64)
        for i in range(N):
            labels[i] = np.random.choice((0, 1, 2), size=(H, W, 1), p=probs)
        # (N, H, W, 1) -> (N, 1, H, W) so the mask broadcasts over channels.
        mask = torch.from_numpy(labels).permute(0, 3, 1, 2).to(input_halfs.device)
        # Out-of-place fills: the caller's tensor is left untouched regardless
        # of which device it lives on.
        noisy = input_halfs.detach().masked_fill(mask == 1, 1.0)
        return noisy.masked_fill(mask == 2, -1.0)

    def forward(self, input_img, decoding_only=False):
        """Run the network.

        Args:
            input_img: when ``decoding_only`` is false, the image fed to the
                encoder (concatenated with one noise channel to form the
                4-channel encoder input); when ``decoding_only`` is true, the
                map that is quantized and decoded directly.
            decoding_only: skip the encoder and only quantize + decode.

        Returns:
            * ``restored`` when ``decoding_only`` is true;
            * ``(halfRes, halfDCT, refDCT, restored)`` when ``self.isTrain``
              is true;
            * ``(halfRes, restored)`` otherwise.
        """
        if decoding_only:
            return self.decoder(self.quantizer(input_img))

        # Noise is sampled with the full shape of the image but only its first
        # channel is fed to the encoder.
        noise = torch.randn_like(input_img) * 0.3
        encoder_input = torch.cat((input_img, noise[:, :1, :, :]), dim=1)
        halfRes = self.encoder(encoder_input)
        restored = self.decoder(self.quantizer(halfRes))

        if not self.isTrain:
            return halfRes, restored

        # Both DCT inputs are rescaled with x / 2 + 0.5 before the transform.
        halfDCT = self.dcter(halfRes / 2. + 0.5)
        refDCT = self.dcter(bgr2gray(input_img / 2. + 0.5))
        return halfRes, halfDCT, refDCT, restored