|
""" |
|
Denoise an image with the FFDNet denoising method |
|
|
|
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr> |
|
|
|
This program is free software: you can use, modify and/or |
|
redistribute it under the terms of the GNU General Public |
|
License as published by the Free Software Foundation, either |
|
version 3 of the License, or (at your option) any later |
|
version. You should have received a copy of this license along |
|
with this program. If not, see <http://www.gnu.org/licenses/>. |
|
""" |
|
import os |
|
import argparse |
|
import time |
|
import numpy as np |
|
import cv2 |
|
import torch |
|
import torch.nn as nn |
|
from torch.autograd import Variable |
|
from denoising.models import FFDNet |
|
from denoising.utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb |
|
|
|
class FFDNetDenoiser:
    """Denoise images with a pretrained FFDNet network.

    The constructor builds the network, loads the pretrained weights found in
    the weights directory and switches the model to evaluation mode.
    ``get_denoised_image`` then runs a single image through the network.
    """

    # Images whose longest side exceeds this many pixels are downscaled
    # (keeping the aspect ratio) before being passed to the network.
    MAX_SIDE = 1200

    def __init__(self, _device, _sigma=25, _weights_dir='denoising/models/', _in_ch=3):
        """Build the model and load its weights.

        Args:
            _device: device selector; the string ``'cuda'`` enables the GPU
                path, any other value keeps everything on the CPU.
            _sigma: default noise level on a 0-255 intensity scale (stored
                divided by 255).
            _weights_dir: directory containing ``net_rgb.pth`` /
                ``net_gray.pth``.
            _in_ch: number of input channels of the network; 3 selects the
                RGB weights, anything else the grayscale weights.
        """
        self.sigma = _sigma / 255
        self.weights_dir = _weights_dir
        self.channels = _in_ch
        self.device = _device

        self.model = FFDNet(num_input_channels=_in_ch)
        self.load_weights()
        self.model.eval()

    def load_weights(self):
        """Load the pretrained weights into ``self.model``.

        The checkpoint is always read onto the CPU and loaded into the bare
        (unwrapped) network; only afterwards is the network wrapped in
        ``nn.DataParallel`` and moved to the GPU when the device is
        ``'cuda'``. Loading before wrapping avoids a key mismatch between
        the prefix-stripped checkpoint and a wrapped model, and makes the
        method safe to call more than once.
        """
        weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
        weights_path = os.path.join(self.weights_dir, weights_name)

        # NOTE(security): torch.load unpickles the file -- only point
        # weights_dir at trusted checkpoints.
        state_dict = torch.load(weights_path, map_location='cpu')
        # Presumably strips the DataParallel key prefix from the checkpoint
        # keys -- confirm in denoising.utils.
        state_dict = remove_dataparallel_wrapper(state_dict)

        # Unwrap a previously wrapped model so the stripped keys match.
        net = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        net.load_state_dict(state_dict)

        if self.device == 'cuda':
            self.model = nn.DataParallel(net, device_ids=[0]).cuda()

    @staticmethod
    def _ensure_three_channels(image):
        """Return ``image`` as an (H, W, C) array, replicating a lone channel 3 times."""
        if image.ndim < 3:
            image = image[:, :, np.newaxis]
        if image.shape[2] == 1:
            # Repeat the existing channel axis; adding another axis first
            # would produce a 4-D array that cannot be transposed later.
            image = np.repeat(image, 3, axis=2)
        return image

    @staticmethod
    def _pad_to_even(batch):
        """Replicate the last row/column of an (N, C, H, W) array so H and W are even.

        Returns the (possibly padded) array and two flags telling whether the
        height and the width were extended by one.
        """
        height, width = batch.shape[2], batch.shape[3]
        expanded_h = height % 2 == 1
        expanded_w = width % 2 == 1
        if expanded_h:
            batch = np.concatenate((batch, batch[:, :, -1:, :]), axis=2)
        if expanded_w:
            batch = np.concatenate((batch, batch[:, :, :, -1:]), axis=3)
        return batch, expanded_h, expanded_w

    def get_denoised_image(self, imorig, sigma=None):
        """Denoise a single image.

        Args:
            imorig: NumPy image, either 2-D (H, W) or 3-D (H, W, C). A
                single-channel image is replicated to three channels.
            sigma: optional noise level on a 0-255 scale overriding the
                default given to the constructor.

        Returns:
            The result of ``variable_to_cv2_image`` applied to the denoised
            tensor (presumably an OpenCV-style image -- confirm in
            denoising.utils). Images larger than ``MAX_SIDE`` are returned
            at the reduced resolution.
        """
        cur_sigma = self.sigma if sigma is None else sigma / 255

        # NOTE(review): the input is expanded to three channels regardless of
        # the channel count the model was built with -- confirm this is what
        # a grayscale model expects.
        imorig = self._ensure_three_channels(imorig)

        longest_side = max(imorig.shape[0], imorig.shape[1])
        if longest_side > self.MAX_SIDE:
            ratio = longest_side / self.MAX_SIDE
            # Clamp to 1 so extremely elongated images never get a zero size.
            new_w = max(1, int(imorig.shape[1] / ratio))
            new_h = max(1, int(imorig.shape[0] / ratio))
            imorig = cv2.resize(imorig, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # HWC -> CHW
        imorig = imorig.transpose(2, 0, 1)

        # Heuristic: values above 1.2 mean the image is not yet scaled to
        # the unit range, so hand it to normalize().
        if imorig.max() > 1.2:
            imorig = normalize(imorig)

        # Add the batch dimension: CHW -> NCHW
        imorig = np.expand_dims(imorig, 0)

        # The network is fed even-sized inputs; the padding is cropped off
        # again after inference.
        imorig, expanded_h, expanded_w = self._pad_to_even(imorig)

        use_cuda = self.device == 'cuda'
        with torch.no_grad():
            imnoisy = torch.Tensor(imorig)
            nsigma = torch.FloatTensor([cur_sigma])
            if use_cuda:
                imnoisy, nsigma = imnoisy.cuda(), nsigma.cuda()

            # The model output is subtracted from the input to obtain the
            # denoised image, clamped to the unit range.
            im_noise_estim = self.model(imnoisy, nsigma)
            outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.)

        if expanded_h:
            outim = outim[:, :, :-1, :]
        if expanded_w:
            outim = outim[:, :, :, :-1]

        return variable_to_cv2_image(outim)
|
|