import glob
import logging
import math
import os
import random
import sys
import time
from collections import OrderedDict
from datetime import datetime
from math import exp
from shutil import get_terminal_size

import cv2
import natsort
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from torch.autograd import Variable
from torchvision.utils import make_grid

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper


####################
# SSIM
####################


def gaussian(window_size, sigma):
    '''Return a 1-D Gaussian kernel of length `window_size` that sums to 1.

    The kernel is centred on index `window_size // 2` and uses standard
    deviation `sigma`.
    '''
    center = window_size // 2
    gauss = torch.Tensor(
        [exp(-(x - center) ** 2 / float(2 * sigma ** 2))
         for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    '''Build the 2-D Gaussian window used as a depthwise convolution kernel.

    The 1-D kernel (sigma fixed at 1.5) is turned into a 2-D kernel via an
    outer product, then expanded to shape
    (channel, 1, window_size, window_size).
    '''
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the column vector with itself -> separable 2-D kernel.
    _2D_window = _1D_window.mm(
        _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(
        channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    '''Compute SSIM between `img1` and `img2` using a precomputed window.

    Local means, variances and the covariance are obtained by convolving
    with `window` using `groups=channel` (one kernel per channel) and
    `window_size // 2` padding.

    Returns the mean of the SSIM map over all elements when `size_average`
    is true, otherwise the map averaged over dims 1..3 (one value per
    entry along dim 0).
    '''
    pad = window_size // 2

    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y].
    sigma1_sq = F.conv2d(img1 * img1, window,
                         padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window,
                         padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window,
                       padding=pad, groups=channel) - mu1_mu2

    # Stabilising constants for the two ratios.
    # NOTE(review): these values presumably assume inputs in [0, 1] - confirm.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
        ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    '''SSIM as a module that caches its Gaussian window between calls.'''

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; `forward` rebuilds it when the
        # input's channel count or tensor type differs.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        '''Return the SSIM of two 4-D inputs (channel count read from dim 1).'''
        (_, channel, _, _) = img1.size()

        if channel == self.channel and \
                self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            # Cache the rebuilt window for subsequent calls.
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size,
                     channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    '''Functional SSIM: builds a fresh window on every call.

    Both inputs are averaged over dim 0 (keeping that dim) before the
    comparison.
    NOTE(review): the `SSIM` module does not average over dim 0, so the two
    entry points can give different results - confirm this is intended.
    '''
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1.mean(dim=0, keepdim=True),
                 img2.mean(dim=0, keepdim=True),
                 window, window_size, channel, size_average)


####################
# yaml
####################


def OrderedYaml():
    '''yaml orderedDict support

    Registers an OrderedDict representer on `Dumper` and a mapping
    constructor on `Loader` that builds OrderedDicts, then returns
    `(Loader, Dumper)`.
    '''
    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper


####################
# miscellaneous
####################


def get_timestamp():
    '''Return the current local time formatted as `yymmdd-HHMMSS`.'''
    return datetime.now().strftime('%y%m%d-%H%M%S')


def mkdir(path):
    '''Create `path` (and missing parents) if nothing exists there yet.'''
    if not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the check above and this call.
        os.makedirs(path, exist_ok=True)


def mkdirs(paths):
    '''Create one directory (if `paths` is a str) or each one in an iterable.'''
    if isinstance(paths, str):
        mkdir(paths)
    else:
        for path in paths:
            mkdir(path)


def mkdir_and_rename(path):
    '''Create `path`, first archiving whatever already exists there.

    An existing `path` is renamed to `<path>_archived_<timestamp>`; the
    rename is reported via `print` and the 'base' logger.
    '''
    if os.path.exists(path):
        new_name = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        logger = logging.getLogger('base')
        logger.info(
            'Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)


def set_random_seed(seed):
    '''Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).'''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def setup_logger(logger_name, root, phase, level=logging.INFO,
                 screen=False, tofile=False):
    '''set up logger

    Configures the logger named `logger_name` with `level`. When `tofile`
    is true a file handler writing to `<root>/<phase>_<timestamp>.log`
    (mode 'w') is attached; when `screen` is true a stream handler is
    attached. Handlers are added on every call.
    '''
    lg = logging.getLogger(logger_name)
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
        datefmt='%y-%m-%d %H:%M:%S')
    lg.setLevel(level)
    if tofile:
        log_file = os.path.join(
            root, phase + '_{}.log'.format(get_timestamp()))
        fh = logging.FileHandler(log_file, mode='w')
        fh.setFormatter(formatter)
        lg.addHandler(fh)
    if screen:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        lg.addHandler(sh)


####################
# image convert
####################


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    Values are clamped to `min_max` and rescaled to [0, 1]. For 4D and 3D
    inputs the channel order is reversed (index [2, 1, 0]) before the
    transpose to HWC. The input tensor is not modified.

    Raises:
        TypeError: if the squeezed tensor is not 2D, 3D or 4D.
    '''
    if hasattr(tensor, 'detach'):
        tensor = tensor.detach()
    # Out-of-place clamp: detach/squeeze/float/cpu can all return tensors
    # sharing storage with the caller's tensor, so an in-place clamp_ would
    # silently overwrite the caller's data.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)  # clamp
    tensor = (tensor - min_max[0]) / \
        (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(
            math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Important. Unlike matlab, casting to numpy.uint8 WILL NOT round by
        # default, so round (and clip) explicitly before the cast.
        img_np = np.clip((img_np * 255.0).round(), 0, 255)
    return img_np.astype(out_type)


def save_img(img, img_path, mode='RGB'):
    '''Write `img` to `img_path` with `cv2.imwrite`. `mode` is unused.'''
    cv2.imwrite(img_path, img)


####################
# metric
####################


def calculate_psnr(img1, img2):
    '''Return the PSNR in dB between two arrays, using 255 as the peak value.

    Returns `float('inf')` when the arrays are identical.
    '''
    # img1 and img2 have range [0, 255]
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def get_resume_paths(opt):
    '''Resolve the training-state and model paths to resume from.

    Reads `opt['path']['resume_state']` and `opt['path']['training_state']`.
    When `resume_state` is "auto" and `training_state` is set, the
    naturally-sorted last entry in the `training_state` directory is used
    and the model path is derived from it by replacing 'training_state'
    with 'models' and '.state' with '_G.pth'; if the directory is empty
    both results are None. Otherwise `resume_state` is returned unchanged
    and the model path is None.

    Returns:
        (resume_state_path, resume_model_path)
    '''
    resume_state_path = None
    resume_model_path = None
    ts = opt_get(opt, ['path', 'training_state'])
    # Use opt_get for both lookups so a missing/None `opt` or `opt['path']`
    # yields None instead of raising AttributeError.
    resume_state = opt_get(opt, ['path', 'resume_state'])
    if resume_state == "auto" and ts is not None:
        wildcard = os.path.join(ts, "*")
        paths = natsort.natsorted(glob.glob(wildcard))
        if paths:
            resume_state_path = paths[-1]
            resume_model_path = resume_state_path.replace(
                'training_state', 'models').replace('.state', '_G.pth')
    else:
        resume_state_path = resume_state
    return resume_state_path, resume_model_path


def opt_get(opt, keys, default=None):
    '''Look up a nested value by following `keys` through `opt` with `.get`.

    Returns `default` when `opt` is None or when any key along the way is
    missing or maps to None.
    '''
    if opt is None:
        return default
    ret = opt
    for k in keys:
        ret = ret.get(k, None)
        if ret is None:
            return default
    return ret