import os

import numpy as np
import torch.nn as nn
from torchvision.transforms.functional import normalize


def denormalize(tensor, mean, std):
    """Undo a ``normalize(tensor, mean, std)`` step.

    Thin functional wrapper around :class:`Denormalize`, so both entry
    points share one implementation and accept the same input types.

    Args:
        tensor: Normalized image, either a torch tensor or a NumPy array.
        mean: Per-channel means used for the original normalization.
        std: Per-channel standard deviations used for the original
            normalization.

    Returns:
        The de-normalized image, i.e. ``tensor * std + mean`` per channel.
    """
    return Denormalize(mean, std)(tensor)


class Denormalize(object):
    """Callable that inverts a per-channel mean/std normalization.

    Normalization computes ``(x - mean) / std``. Applying the same formula
    again with ``-mean / std`` as the mean and ``1 / std`` as the std yields
    ``x * std + mean``, which is the original value.
    """

    def __init__(self, mean, std):
        """Precompute the inverse normalization parameters.

        Args:
            mean: Per-channel means used for the original normalization.
            std: Per-channel standard deviations used for the original
                normalization.
        """
        mean = np.array(mean)
        std = np.array(std)
        self._mean = -mean / std
        self._std = 1 / std

    def __call__(self, tensor):
        """De-normalize ``tensor``.

        Args:
            tensor: A NumPy array or a torch tensor. NumPy input is
                broadcast against parameters reshaped to ``(-1, 1, 1)``,
                so the channel axis must be third from the last.

        Returns:
            The de-normalized array (NumPy input) or the result of
            torchvision's ``normalize`` (any other input).
        """
        if isinstance(tensor, np.ndarray):
            # Reshape to (C, 1, 1) so the parameters broadcast over H and W.
            shift = self._mean.reshape(-1, 1, 1)
            scale = self._std.reshape(-1, 1, 1)
            return (tensor - shift) / scale
        return normalize(tensor, self._mean, self._std)


def set_bn_momentum(model, momentum=0.1):
    """Set ``momentum`` on every ``nn.BatchNorm2d`` module inside ``model``.

    Args:
        model: Module whose sub-modules are scanned via ``model.modules()``.
        momentum: Value assigned to each matching layer's ``momentum``.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = momentum


def fix_bn(model):
    """Switch every ``nn.BatchNorm2d`` module inside ``model`` to eval mode.

    Only ``eval()`` is called on the matching layers; other module types are
    left untouched.

    Args:
        model: Module whose sub-modules are scanned via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()


def mkdir(path):
    """Create directory ``path``, including any missing parent directories.

    Does nothing if ``path`` already exists. Creation is attempted directly
    rather than checking for existence first, so a concurrent creator
    cannot cause a spurious failure.

    Args:
        path: Directory path to create.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present (possibly created concurrently): nothing to do.
        pass