diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..3c443bf0a2f108cba50ad512c57e26bfa889bdef --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +pretrained/ohayou_face.pt filter=lfs diff=lfs merge=lfs -text +pretrained/ohayou_face.pkl filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b899d5c1af1d3a38a2a76995744b3211adad4de6 --- /dev/null +++ b/README.md @@ -0,0 +1,11 @@ +--- +title: Ohayou_Face +emoji: ⚡ +colorFrom: red +colorTo: yellow +sdk: gradio +app_file: app.py +pinned: false +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ce1534580f416250dd222d61b8a8c9027c400c6c --- /dev/null +++ b/app.py @@ -0,0 +1,31 @@ +import os +from PIL import Image +import gradio as gr +from torchvision import transforms +import easydict +import torch +import numpy as np +import model_build + + +psp = model_build.build_psp() +stylegan2 = model_build.build_stylegan2() + +pretransform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + +def pipeline(img): + img = model_build.img_preprocess(img, pretransform) + with torch.no_grad(): + _, latent_space = psp(img.float(), randomize_noise=True, resize=False, return_latents=True) + img = stylegan2(latent_space, noise_mode='none') + img = Image.fromarray(np.array((img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).squeeze(0)[20:-20,:,:])) + img.save('output.png') + return 'output.png' + +examples=[['momoi_out.png',False], ['churuki_out.png', False], ['fgfgfggf.png', False], ['dsfd.png', False]] +description="The male image doesn't work well. 1:1 ratio image recommended (square cropable after uploading). If the background is not monochromatic, it can be mixed with hair color. It takes an average of 5 seconds, but it can take longer if there is a lot of traffic. 남성 이미지에는 잘 작동하지 않음. 1:1비율 권장(업로드 후 정사각형 자르기 가능), 배경이 단색이 아니면 머리색과 섞일 수 있음. 트래픽이 많으면 5초 이상 걸릴 수 있음. Email:krkmfn@gmail.com" +gr.Interface(pipeline, [gr.inputs.Image(type="pil")], gr.outputs.Image(type="file"),description=description,allow_flagging=False,examples=examples,allow_screenshot=False,enable_queue=False).launch() + \ No newline at end of file diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/data_configs.py b/configs/data_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..3b1202785b880f85bbc3f79c26876bf54e8dd95f --- /dev/null +++ b/configs/data_configs.py @@ -0,0 +1,41 @@ +from configs import transforms_config +from configs.paths_config import dataset_paths + + +DATASETS = { + 'ffhq_encode': { + 'transforms': transforms_config.EncodeTransforms, + 'train_source_root': dataset_paths['ffhq'], + 'train_target_root': dataset_paths['ffhq'], + 'test_source_root': dataset_paths['celeba_test'], + 'test_target_root': dataset_paths['celeba_test'], + }, + 'furry': { + 'transforms': transforms_config.FrontalizationTransforms, + 'train_source_root': dataset_paths['anime'], + 'train_target_root': dataset_paths['anime'], + 'test_source_root': dataset_paths['gogal'], + 'test_target_root': dataset_paths['gogal'], + }, + 'celebs_sketch_to_face': { + 'transforms': transforms_config.SketchToImageTransforms, + 'train_source_root': dataset_paths['celeba_train_sketch'], + 'train_target_root': dataset_paths['celeba_train'], + 'test_source_root': dataset_paths['celeba_test_sketch'], + 'test_target_root': dataset_paths['celeba_test'], + }, + 'celebs_seg_to_face': { + 'transforms': transforms_config.SegToImageTransforms, + 'train_source_root': dataset_paths['celeba_train_segmentation'], + 'train_target_root': dataset_paths['celeba_train'], + 'test_source_root': dataset_paths['celeba_test_segmentation'], + 'test_target_root': dataset_paths['celeba_test'], + }, + 'celebs_super_resolution': { + 'transforms': transforms_config.SuperResTransforms, + 'train_source_root': dataset_paths['celeba_train'], + 'train_target_root': dataset_paths['celeba_train'], + 'test_source_root': dataset_paths['celeba_test'], + 'test_target_root': dataset_paths['celeba_test'], + }, +} diff --git a/configs/paths_config.py b/configs/paths_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2df26853d8bde0c82967ab208761e2cb8f87e62b --- /dev/null +++ b/configs/paths_config.py @@ -0,0 +1,23 @@ +dataset_paths = { + 'celeba_train': '', + 'celeba_test': '', + 'celeba_train_sketch': '', + 'celeba_test_sketch': '', + 'celeba_train_segmentation': '', + 'celeba_test_segmentation': '', + 'ffhq': '', + 'anime' : '/content/drive/MyDrive/Dataset/anime', + 'gogal' : '/content/drive/MyDrive/All Data/고갈왕' +} + +model_paths = { + 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt', + 'ir_se50': 'pretrained_models/model_ir_se50.pth', + 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth', + 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy', + 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy', + 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy', + 'shape_predictor': 'shape_predictor_68_face_landmarks.dat', + 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar', + 'anime' : '/content/drive/MyDrive/StyleGAN2-ada/result/pretrained/anime_face.pt' +} diff --git a/configs/transforms_config.py b/configs/transforms_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0b37189bfa08c0691897051874a3a35c1d325e85 --- /dev/null +++ b/configs/transforms_config.py @@ -0,0 +1,152 @@ +from abc import abstractmethod +import torchvision.transforms as transforms +from datasets import augmentations + + +class TransformsConfig(object): + + def __init__(self, opts): + self.opts = opts + + @abstractmethod + def get_transforms(self): + pass + + +class EncodeTransforms(TransformsConfig): + + def __init__(self, opts): + super(EncodeTransforms, self).__init__(opts) + + def get_transforms(self): + transforms_dict = { + 'transform_gt_train': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomHorizontalFlip(0.5), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_source': None, + 'transform_test': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_inference': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + } + return transforms_dict + + +class FrontalizationTransforms(TransformsConfig): + + def __init__(self, opts): + super(FrontalizationTransforms, self).__init__(opts) + + def get_transforms(self): + transforms_dict = { + 'transform_gt_train': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_source': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_test': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_inference': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + } + return transforms_dict + + +class SketchToImageTransforms(TransformsConfig): + + def __init__(self, opts): + super(SketchToImageTransforms, self).__init__(opts) + + def get_transforms(self): + transforms_dict = { + 'transform_gt_train': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_source': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor()]), + 'transform_test': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_inference': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor()]), + } + return transforms_dict + + +class SegToImageTransforms(TransformsConfig): + + def __init__(self, opts): + super(SegToImageTransforms, self).__init__(opts) + + def get_transforms(self): + transforms_dict = { + 'transform_gt_train': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_source': transforms.Compose([ + transforms.Resize((256, 256)), + augmentations.ToOneHot(self.opts.label_nc), + transforms.ToTensor()]), + 'transform_test': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_inference': transforms.Compose([ + transforms.Resize((256, 256)), + augmentations.ToOneHot(self.opts.label_nc), + transforms.ToTensor()]) + } + return transforms_dict + + +class SuperResTransforms(TransformsConfig): + + def __init__(self, opts): + super(SuperResTransforms, self).__init__(opts) + + def get_transforms(self): + if self.opts.resize_factors is None: + self.opts.resize_factors = '1,2,4,8,16,32' + factors = [int(f) for f in self.opts.resize_factors.split(",")] + print("Performing down-sampling with factors: {}".format(factors)) + transforms_dict = { + 'transform_gt_train': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_source': transforms.Compose([ + transforms.Resize((256, 256)), + augmentations.BilinearResize(factors=factors), + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_test': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_inference': transforms.Compose([ + transforms.Resize((256, 256)), + augmentations.BilinearResize(factors=factors), + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + } + return transforms_dict diff --git a/criteria/__init__.py b/criteria/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/criteria/id_loss.py b/criteria/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1608ec1eb575e88035aba73c5b6595b4722db5b8 --- /dev/null +++ b/criteria/id_loss.py @@ -0,0 +1,44 @@ +import torch +from torch import nn +from configs.paths_config import model_paths +from models.encoders.model_irse import Backbone + + +class IDLoss(nn.Module): + def __init__(self): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + def extract_feats(self, x): + x = x[:, :, 35:223, 32:220] # Crop interesting region + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + id_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + id_logs.append({'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views)}) + loss += 1 - diff_target + id_diff = float(diff_target) - float(diff_views) + sim_improvement += id_diff + count += 1 + + return loss / count, sim_improvement / count, id_logs diff --git a/criteria/lpips/__init__.py b/criteria/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/criteria/lpips/lpips.py b/criteria/lpips/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..1add6acc84c1c04cfcb536cf31ec5acdf24b716b --- /dev/null +++ b/criteria/lpips/lpips.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from criteria.lpips.networks import get_network, LinLayers +from criteria.lpips.utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type).to("cuda") + + # linear layers + self.lin = LinLayers(self.net.n_channels_list).to("cuda") + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0)) / x.shape[0] diff --git a/criteria/lpips/networks.py b/criteria/lpips/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0d13ad2d560278f16586da68d3a5eadb26e746 --- /dev/null +++ b/criteria/lpips/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from criteria.lpips.utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(True).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) \ No newline at end of file diff --git a/criteria/lpips/utils.py b/criteria/lpips/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5 --- /dev/null +++ b/criteria/lpips/utils.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/criteria/moco_loss.py b/criteria/moco_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4e6f04dde1a929862012395e4b873804ef2bbc00 --- /dev/null +++ b/criteria/moco_loss.py @@ -0,0 +1,69 @@ +import torch +from torch import nn +import torch.nn.functional as F +from configs.paths_config import model_paths + + +class MocoLoss(nn.Module): + + def __init__(self): + super(MocoLoss, self).__init__() + print("Loading MOCO model from path: {}".format(model_paths["moco"])) + self.model = self.__load_model() + self.model.cuda() + self.model.eval() + + @staticmethod + def __load_model(): + import torchvision.models as models + model = models.__dict__["resnet50"]() + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ['fc.weight', 'fc.bias']: + param.requires_grad = False + checkpoint = torch.load(model_paths['moco'], map_location="cpu") + state_dict = checkpoint['state_dict'] + # rename moco pre-trained keys + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): + # remove prefix + state_dict[k[len("module.encoder_q."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + # remove output layer + model = nn.Sequential(*list(model.children())[:-1]).cuda() + return model + + def extract_feats(self, x): + x = F.interpolate(x, size=224) + x_feats = self.model(x) + x_feats = nn.functional.normalize(x_feats, dim=1) + x_feats = x_feats.squeeze() + return x_feats + + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + sim_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + sim_logs.append({'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views)}) + loss += 1 - diff_target + sim_diff = float(diff_target) - float(diff_views) + sim_improvement += sim_diff + count += 1 + + return loss / count, sim_improvement / count, sim_logs diff --git a/criteria/w_norm.py b/criteria/w_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..a45ab6f67d8a3f7051be4b7236fa2f38446fd2c1 --- /dev/null +++ b/criteria/w_norm.py @@ -0,0 +1,14 @@ +import torch +from torch import nn + + +class WNormLoss(nn.Module): + + def __init__(self, start_from_latent_avg=True): + super(WNormLoss, self).__init__() + self.start_from_latent_avg = start_from_latent_avg + + def forward(self, latent, latent_avg=None): + if self.start_from_latent_avg: + latent = latent - latent_avg + return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/datasets/augmentations.py b/datasets/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0507f155fa32a463b9bd4b2f50099fd1866df0 --- /dev/null +++ b/datasets/augmentations.py @@ -0,0 +1,110 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import transforms + + +class ToOneHot(object): + """ Convert the input PIL image to a one-hot torch tensor """ + def __init__(self, n_classes=None): + self.n_classes = n_classes + + def onehot_initialization(self, a): + if self.n_classes is None: + self.n_classes = len(np.unique(a)) + out = np.zeros(a.shape + (self.n_classes, ), dtype=int) + out[self.__all_idx(a, axis=2)] = 1 + return out + + def __all_idx(self, idx, axis): + grid = np.ogrid[tuple(map(slice, idx.shape))] + grid.insert(axis, idx) + return tuple(grid) + + def __call__(self, img): + img = np.array(img) + one_hot = self.onehot_initialization(img) + return one_hot + + +class BilinearResize(object): + def __init__(self, factors=[1, 2, 4, 8, 16, 32]): + self.factors = factors + + def __call__(self, image): + factor = np.random.choice(self.factors, size=1)[0] + D = BicubicDownSample(factor=factor, cuda=False) + img_tensor = transforms.ToTensor()(image).unsqueeze(0) + img_tensor_lr = D(img_tensor)[0].clamp(0, 1) + img_low_res = transforms.ToPILImage()(img_tensor_lr) + return img_low_res + + +class BicubicDownSample(nn.Module): + def bicubic_kernel(self, x, a=-0.50): + """ + This equation is exactly copied from the website below: + https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic + """ + abs_x = torch.abs(x) + if abs_x <= 1.: + return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1 + elif 1. < abs_x < 2.: + return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a + else: + return 0.0 + + def __init__(self, factor=4, cuda=True, padding='reflect'): + super().__init__() + self.factor = factor + size = factor * 4 + k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) + for i in range(size)], dtype=torch.float32) + k = k / torch.sum(k) + k1 = torch.reshape(k, shape=(1, 1, size, 1)) + self.k1 = torch.cat([k1, k1, k1], dim=0) + k2 = torch.reshape(k, shape=(1, 1, 1, size)) + self.k2 = torch.cat([k2, k2, k2], dim=0) + self.cuda = '.cuda' if cuda else '' + self.padding = padding + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x, nhwc=False, clip_round=False, byte_output=False): + filter_height = self.factor * 4 + filter_width = self.factor * 4 + stride = self.factor + + pad_along_height = max(filter_height - stride, 0) + pad_along_width = max(filter_width - stride, 0) + filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda)) + filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda)) + + # compute actual padding values for each side + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + + # apply mirror padding + if nhwc: + x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW + + # downscaling performed by 1-d convolution + x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) + x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.) + + x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) + x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.) + + if nhwc: + x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) + if byte_output: + return x.type('torch.ByteTensor'.format(self.cuda)) + else: + return x diff --git a/datasets/gt_res_dataset.py b/datasets/gt_res_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8892efabcfad7b902c5d49e4b496001241e7ed99 --- /dev/null +++ b/datasets/gt_res_dataset.py @@ -0,0 +1,32 @@ +#!/usr/bin/python +# encoding: utf-8 +import os +from torch.utils.data import Dataset +from PIL import Image + + +class GTResDataset(Dataset): + + def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): + self.pairs = [] + for f in os.listdir(root_path): + image_path = os.path.join(root_path, f) + gt_path = os.path.join(gt_dir, f) + if f.endswith(".jpg") or f.endswith(".png"): + self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None]) + self.transform = transform + self.transform_train = transform_train + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + from_path, to_path, _ = self.pairs[index] + from_im = Image.open(from_path).convert('RGB') + to_im = Image.open(to_path).convert('RGB') + + if self.transform: + to_im = self.transform(to_im) + from_im = self.transform(from_im) + + return from_im, to_im diff --git a/datasets/images_dataset.py b/datasets/images_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..62bb3e3eb85f3841696bac02fa5fb217488a43cd --- /dev/null +++ b/datasets/images_dataset.py @@ -0,0 +1,33 @@ +from torch.utils.data import Dataset +from PIL import Image +from utils import data_utils + + +class ImagesDataset(Dataset): + + def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): + self.source_paths = sorted(data_utils.make_dataset(source_root)) + self.target_paths = sorted(data_utils.make_dataset(target_root)) + self.source_transform = source_transform + self.target_transform = target_transform + self.opts = opts + + def __len__(self): + return len(self.source_paths) + + def __getitem__(self, index): + from_path = self.source_paths[index] + from_im = Image.open(from_path) + from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') + + to_path = self.target_paths[index] + to_im = Image.open(to_path).convert('RGB') + if self.target_transform: + to_im = self.target_transform(to_im) + + if self.source_transform: + from_im = self.source_transform(from_im) + else: + from_im = to_im + + return from_im, to_im diff --git a/datasets/inference_dataset.py b/datasets/inference_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..de457349b0726932176f21814c61e34f15955bb7 --- /dev/null +++ b/datasets/inference_dataset.py @@ -0,0 +1,22 @@ +from torch.utils.data import Dataset +from PIL import Image +from utils import data_utils + + +class InferenceDataset(Dataset): + + def __init__(self, root, opts, transform=None): + self.paths = sorted(data_utils.make_dataset(root)) + self.transform = transform + self.opts = opts + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + from_path = self.paths[index] + from_im = Image.open(from_path) + from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') + if self.transform: + from_im = self.transform(from_im) + return from_im diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f08cf36f11f9b0fd94c1b7caeadf69b98375b04 --- /dev/null +++ b/dnnlib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .util import EasyDict, make_cache_dir_path diff --git a/dnnlib/util.py b/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..76725336d01e75e1c68daa88be47f4fde0bbc63b --- /dev/null +++ b/dnnlib/util.py @@ -0,0 +1,477 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/legacy.py b/legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..ae91b918c7d8c270028d39bac1fa56382d71d162 --- /dev/null +++ b/legacy.py @@ -0,0 +1,384 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import click +import pickle +import re +import copy +import numpy as np +import torch +import dnnlib +from torch_utils import misc + +#---------------------------------------------------------------------------- + +# !!! custom +def load_network_pkl(f, force_fp16=False, custom=False, **ex_kwargs): +# def load_network_pkl(f, force_fp16=False): + data = _LegacyUnpickler(f).load() + # data = pickle.load(f, encoding='latin1') + + # Legacy TensorFlow pickle => convert. + if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): + tf_G, tf_D, tf_Gs = data + G = convert_tf_generator(tf_G, custom=custom, **ex_kwargs) + D = convert_tf_discriminator(tf_D) + G_ema = convert_tf_generator(tf_Gs, custom=custom, **ex_kwargs) + data = dict(G=G, D=D, G_ema=G_ema) +# !!! custom + assert isinstance(data['G'], torch.nn.Module) + assert isinstance(data['D'], torch.nn.Module) + nets = ['G', 'D', 'G_ema'] + elif isinstance(data, _TFNetworkStub): + G_ema = convert_tf_generator(data, custom=custom, **ex_kwargs) + data = dict(G_ema=G_ema) + nets = ['G_ema'] + else: +# !!! custom + if custom is True: + G_ema = custom_generator(data, **ex_kwargs) + data = dict(G_ema=G_ema) + nets = ['G_ema'] + else: + nets = [] + for name in ['G', 'D', 'G_ema']: + if name in data.keys(): + nets.append(name) + # print(nets) + + # Add missing fields. + if 'training_set_kwargs' not in data: + data['training_set_kwargs'] = None + if 'augment_pipe' not in data: + data['augment_pipe'] = None + + # Validate contents. + assert isinstance(data['G_ema'], torch.nn.Module) + assert isinstance(data['training_set_kwargs'], (dict, type(None))) + assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) + + # Force FP16. + if force_fp16: + for key in nets: # !!! custom + old = data[key] + kwargs = copy.deepcopy(old.init_kwargs) + if key.startswith('G'): + kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {})) + kwargs.synthesis_kwargs.num_fp16_res = 4 + kwargs.synthesis_kwargs.conv_clamp = 256 + if key.startswith('D'): + kwargs.num_fp16_res = 4 + kwargs.conv_clamp = 256 + if kwargs != old.init_kwargs: + new = type(old)(**kwargs).eval().requires_grad_(False) + misc.copy_params_and_buffers(old, new, require_all=True) + data[key] = new + return data + +#---------------------------------------------------------------------------- + +class _TFNetworkStub(dnnlib.EasyDict): + pass + +class _LegacyUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module == 'dnnlib.tflib.network' and name == 'Network': + return _TFNetworkStub + return super().find_class(module, name) + +#---------------------------------------------------------------------------- + +def _collect_tf_params(tf_net): + # pylint: disable=protected-access + tf_params = dict() + def recurse(prefix, tf_net): + for name, value in tf_net.variables: + tf_params[prefix + name] = value + for name, comp in tf_net.components.items(): + recurse(prefix + name + '/', comp) + recurse('', tf_net) + return tf_params + +#---------------------------------------------------------------------------- + +def _populate_module_params(module, *patterns): + for name, tensor in misc.named_params_and_buffers(module): + found = False + value = None + for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): + match = re.fullmatch(pattern, name) + if match: + found = True + if value_fn is not None: + value = value_fn(*match.groups()) + break + try: + assert found + if value is not None: + tensor.copy_(torch.from_numpy(np.array(value))) + except: + print(name, list(tensor.shape)) + raise + +#---------------------------------------------------------------------------- + +# !!! custom +def custom_generator(data, **ex_kwargs): + from training import stylegan2_multi as networks + try: # saved? (with new fix) + fmap_base = data['G_ema'].synthesis.fmap_base + except: # default from original configs + fmap_base = 32768 if data['G_ema'].img_resolution >= 512 else 16384 + kwargs = dnnlib.EasyDict( + z_dim = data['G_ema'].z_dim, + c_dim = data['G_ema'].c_dim, + w_dim = data['G_ema'].w_dim, + img_resolution = data['G_ema'].img_resolution, + img_channels = data['G_ema'].img_channels, + init_res = [4,4], # hacky + mapping_kwargs = dnnlib.EasyDict(num_layers = data['G_ema'].mapping.num_layers), + synthesis_kwargs = dnnlib.EasyDict(channel_base = fmap_base, **ex_kwargs), + ) + G_out = networks.Generator(**kwargs).eval().requires_grad_(False) + misc.copy_params_and_buffers(data['G_ema'], G_out, require_all=False) + return G_out + +# !!! custom +def convert_tf_generator(tf_G, custom=False, **ex_kwargs): +# def convert_tf_generator(tf_G): + if tf_G.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. + tf_kwargs = tf_G.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None, none=None): + known_kwargs.add(tf_name) + val = tf_kwargs.get(tf_name, default) + return val if val is not None else none + + # Convert kwargs. + kwargs = dnnlib.EasyDict( + z_dim = kwarg('latent_size', 512), + c_dim = kwarg('label_size', 0), + w_dim = kwarg('dlatent_size', 512), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 8), + embed_features = kwarg('label_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('mapping_nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.01), + w_avg_beta = kwarg('w_avg_beta', 0.995, none=1), + ), + synthesis_kwargs = dnnlib.EasyDict( + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + architecture = kwarg('architecture', 'skip'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + use_noise = kwarg('use_noise', True), + activation = kwarg('nonlinearity', 'lrelu'), + ), +# !!! custom + # init_res = kwarg('init_res', [4,4]), + ) + + # Check for unknown kwargs. + kwarg('truncation_psi') + kwarg('truncation_cutoff') + kwarg('style_mixing_prob') + kwarg('structure') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) +# !!! custom + if custom: + kwargs.init_res = [4,4] + kwargs.synthesis_kwargs = dnnlib.EasyDict(**kwargs.synthesis_kwargs, **ex_kwargs) + if len(unknown_kwargs) > 0: + print('Unknown TensorFlow data! This may result in problems with your converted model.') + print(unknown_kwargs) + #raise ValueError('Unknown TensorFlow kwargs:', unknown_kwargs) + # raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + # try: + # if ex_kwargs['verbose'] is True: print(kwargs.synthesis_kwargs) + # except: pass + + # Collect params. + tf_params = _collect_tf_params(tf_G) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value + kwargs.synthesis.kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. + if custom: + from training import stylegan2_multi as networks + else: + from training import networks + G = networks.Generator(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + _populate_module_params(G, + r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], + r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], + r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], + r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], + r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], + r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], + r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), + r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], + r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], + r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], + r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], + r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], + r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], + r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, + r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], + r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, + r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'.*\.resample_filter', None, + ) + return G + +#---------------------------------------------------------------------------- + +def convert_tf_discriminator(tf_D): + if tf_D.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. + tf_kwargs = tf_D.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None): + known_kwargs.add(tf_name) + return tf_kwargs.get(tf_name, default) + + # Convert kwargs. + kwargs = dnnlib.EasyDict( + c_dim = kwarg('label_size', 0), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + architecture = kwarg('architecture', 'resnet'), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + cmap_dim = kwarg('mapping_fmaps', None), + block_kwargs = dnnlib.EasyDict( + activation = kwarg('nonlinearity', 'lrelu'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + freeze_layers = kwarg('freeze_layers', 0), + ), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 0), + embed_features = kwarg('mapping_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.1), + ), + epilogue_kwargs = dnnlib.EasyDict( + mbstd_group_size = kwarg('mbstd_group_size', None), + mbstd_num_channels = kwarg('mbstd_num_features', 1), + activation = kwarg('nonlinearity', 'lrelu'), + ), +# !!! custom + # init_res = kwarg('init_res', [4,4]), + ) + + # Check for unknown kwargs. + kwarg('structure') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + print('Unknown TensorFlow data! This may result in problems with your converted model.') + print(unknown_kwargs) + # originally this repo threw errors: + # raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. + tf_params = _collect_tf_params(tf_D) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value + kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. + from training import networks + D = networks.Discriminator(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + _populate_module_params(D, + r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], + r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], + r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), + r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], + r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], + r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), + r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], + r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), + r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], + r'.*\.resample_filter', None, + ) + return D + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--source', help='Input pickle', required=True, metavar='PATH') +@click.option('--dest', help='Output pickle', required=True, metavar='PATH') +@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) +def convert_network_pickle(source, dest, force_fp16): + """Convert legacy network pickle into the native PyTorch format. + + The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. + It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. + + Example: + + \b + python legacy.py \\ + --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ + --dest=stylegan2-cat-config-f.pkl + """ + print(f'Loading "{source}"...') + with dnnlib.util.open_url(source) as f: + data = load_network_pkl(f, force_fp16=force_fp16) + print(f'Saving "{dest}"...') + with open(dest, 'wb') as f: + pickle.dump(data, f) + print('Done.') + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + convert_network_pickle() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/model_build.py b/model_build.py new file mode 100644 index 0000000000000000000000000000000000000000..b06d8ef428ad7856adc72f49caff2652cc6f7f26 --- /dev/null +++ b/model_build.py @@ -0,0 +1,95 @@ +import os +import glob + +import numpy as np +from numpy import linalg +import PIL.Image as Image +import torch +from torchvision import transforms +from tqdm import tqdm +from argparse import Namespace +import easydict + +import legacy +import dnnlib + +from opensimplex import OpenSimplex + +from configs import data_configs +from models.psp import pSp + + +def build_stylegan2( + increment = 0.01, + network_pkl = 'pretrained/furry.pkl', + process = 'image', #['image', 'interpolation','truncation','interpolation-truncation'] + random_seed = 0, + diameter = 100.0, + scale_type = 'pad', #['pad', 'padside', 'symm','symmside'] + size = [512, 512], + seeds = [0], + space = 'z', #['z', 'w'] + fps = 24, + frames = 240, + noise_mode = 'none', #['const', 'random', 'none'] + outdir = 'path', + projected_w = 'path', + easing = 'linear', + device = 'cpu' + + ): + + G_kwargs = dnnlib.EasyDict() + G_kwargs.size = size + G_kwargs.scale_type = scale_type + + device = torch.device(device) + with dnnlib.util.open_url(network_pkl) as f: + # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + G = legacy.load_network_pkl(f, custom=True, **G_kwargs)['G_ema'].to(device) # type: ignore + + return G.synthesis + + +def build_psp(): + test_opts = easydict.EasyDict({ + # arguments for inference script + 'checkpoint_path' : 'pretrained/psp.pt', + 'couple_outputs' : False, + 'resize_outputs' : False, + + 'test_batch_size' : 1, + 'test_workers' : 1, + + # arguments for style-mixing script + 'n_images' : None, + 'n_outputs_to_generate' : 5, + 'mix_alpha' : None, + 'latent_mask' : None, + + # arguments for super-resolution + 'resize_factors' : None, + }) + + # update test options with options used during training + ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') + opts = ckpt['opts'] + opts.update(vars(test_opts)) + if 'learn_in_w' not in opts: + opts['learn_in_w'] = False + opts = Namespace(**opts) + opts.device = 'cpu' + net = pSp(opts) + net.eval() + return net + +def img_preprocess(img, transform): + if (img.mode == 'RGBA') or (img.mode == 'P'): + img.load() + background = Image.new("RGB", img.size, (255, 255, 255)) + background.paste(img, mask=img.split()[3]) # 3 is the alpha channel + img = background + assert img.mode == 'RGB' + img = transform(img) + img = img.unsqueeze(dim=0) + return img \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/encoders/__init__.py b/models/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/encoders/helpers.py b/models/encoders/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b51fdf97141407fcc1c9d249a086ddbfd042469f --- /dev/null +++ b/models/encoders/helpers.py @@ -0,0 +1,119 @@ +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/models/encoders/model_irse.py b/models/encoders/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..bc41ace0ba04cf4285c283a28e6c36113a18e6d6 --- /dev/null +++ b/models/encoders/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/models/encoders/psp_encoders.py b/models/encoders/psp_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..9323f06854173ca55d0c1b693cf745c01a84effb --- /dev/null +++ b/models/encoders/psp_encoders.py @@ -0,0 +1,186 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module + +from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE +from models.stylegan2.model import EqualLinear + + +class GradualStyleBlock(Module): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.view(-1, self.out_c) + x = self.linear(x) + return x + + +class GradualStyleEncoder(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + self.style_count = opts.n_styles + self.coarse_ind = 3 + self.middle_ind = 7 + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + def _upsample_add(self, x, y): + '''Upsample and add two feature maps. + Args: + x: (Variable) top feature map to be upsampled. + y: (Variable) lateral feature map. + Returns: + (Variable) added feature map. + Note in PyTorch, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. + ''' + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = self._upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = self._upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = torch.stack(latents, dim=1) + return out + + +class BackboneEncoderUsingLastLayerIntoW(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(BackboneEncoderUsingLastLayerIntoW, self).__init__() + print('Using BackboneEncoderUsingLastLayerIntoW') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = EqualLinear(512, 512, lr_mul=1) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_pool(x) + x = x.view(-1, 512) + x = self.linear(x) + return x + + +class BackboneEncoderUsingLastLayerIntoWPlus(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__() + print('Using BackboneEncoderUsingLastLayerIntoWPlus') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.n_styles = opts.n_styles + self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + self.output_layer_2 = Sequential(BatchNorm2d(512), + torch.nn.AdaptiveAvgPool2d((7, 7)), + Flatten(), + Linear(512 * 7 * 7, 512)) + self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer_2(x) + x = self.linear(x) + x = x.view(-1, self.n_styles, 512) + return x diff --git a/models/mtcnn/__init__.py b/models/mtcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/mtcnn/mtcnn.py b/models/mtcnn/mtcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..4deacabaaf35e315c363c9eada9ff0c41f2561e5 --- /dev/null +++ b/models/mtcnn/mtcnn.py @@ -0,0 +1,156 @@ +import numpy as np +import torch +from PIL import Image +from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet +from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square +from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage +from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face + +device = 'cuda:0' + + +class MTCNN(): + def __init__(self): + print(device) + self.pnet = PNet().to(device) + self.rnet = RNet().to(device) + self.onet = ONet().to(device) + self.pnet.eval() + self.rnet.eval() + self.onet.eval() + self.refrence = get_reference_facial_points(default_square=True) + + def align(self, img): + _, landmarks = self.detect_faces(img) + if len(landmarks) == 0: + return None, None + facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] + warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) + return Image.fromarray(warped_face), tfm + + def align_multi(self, img, limit=None, min_face_size=30.0): + boxes, landmarks = self.detect_faces(img, min_face_size) + if limit: + boxes = boxes[:limit] + landmarks = landmarks[:limit] + faces = [] + tfms = [] + for landmark in landmarks: + facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] + warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) + faces.append(Image.fromarray(warped_face)) + tfms.append(tfm) + return boxes, faces, tfms + + def detect_faces(self, image, min_face_size=20.0, + thresholds=[0.15, 0.25, 0.35], + nms_thresholds=[0.7, 0.7, 0.7]): + """ + Arguments: + image: an instance of PIL.Image. + min_face_size: a float number. + thresholds: a list of length 3. + nms_thresholds: a list of length 3. + + Returns: + two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], + bounding boxes and facial landmarks. + """ + + # BUILD AN IMAGE PYRAMID + width, height = image.size + min_length = min(height, width) + + min_detection_size = 12 + factor = 0.707 # sqrt(0.5) + + # scales for scaling the image + scales = [] + + # scales the image so that + # minimum size that we can detect equals to + # minimum face size that we want to detect + m = min_detection_size / min_face_size + min_length *= m + + factor_count = 0 + while min_length > min_detection_size: + scales.append(m * factor ** factor_count) + min_length *= factor + factor_count += 1 + + # STAGE 1 + + # it will be returned + bounding_boxes = [] + + with torch.no_grad(): + # run P-Net on different scales + for s in scales: + boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0]) + bounding_boxes.append(boxes) + + # collect boxes (and offsets, and scores) from different scales + bounding_boxes = [i for i in bounding_boxes if i is not None] + bounding_boxes = np.vstack(bounding_boxes) + + keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) + bounding_boxes = bounding_boxes[keep] + + # use offsets predicted by pnet to transform bounding boxes + bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) + # shape [n_boxes, 5] + + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 2 + + img_boxes = get_image_boxes(bounding_boxes, image, size=24) + img_boxes = torch.FloatTensor(img_boxes).to(device) + + output = self.rnet(img_boxes) + offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] + probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 1] > thresholds[1])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) + offsets = offsets[keep] + + keep = nms(bounding_boxes, nms_thresholds[1]) + bounding_boxes = bounding_boxes[keep] + bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 3 + + img_boxes = get_image_boxes(bounding_boxes, image, size=48) + if len(img_boxes) == 0: + return [], [] + img_boxes = torch.FloatTensor(img_boxes).to(device) + output = self.onet(img_boxes) + landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] + offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] + probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 1] > thresholds[2])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) + offsets = offsets[keep] + landmarks = landmarks[keep] + + # compute landmark points + width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 + height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 + xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] + landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] + landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] + + bounding_boxes = calibrate_box(bounding_boxes, offsets) + keep = nms(bounding_boxes, nms_thresholds[2], mode='min') + bounding_boxes = bounding_boxes[keep] + landmarks = landmarks[keep] + + return bounding_boxes, landmarks diff --git a/models/mtcnn/mtcnn_pytorch/__init__.py b/models/mtcnn/mtcnn_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/mtcnn/mtcnn_pytorch/src/__init__.py b/models/mtcnn/mtcnn_pytorch/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..617ba38c34b1801b2db2e0209b4e886c9d24c490 --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/__init__.py @@ -0,0 +1,2 @@ +from .visualization_utils import show_bboxes +from .detector import detect_faces diff --git a/models/mtcnn/mtcnn_pytorch/src/align_trans.py b/models/mtcnn/mtcnn_pytorch/src/align_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5f1df002bc19556ae8a75cabf56310084785a9 --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/align_trans.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Apr 24 15:43:29 2017 +@author: zhaoy +""" +import numpy as np +import cv2 + +# from scipy.linalg import lstsq +# from scipy.ndimage import geometric_transform # , map_coordinates + +from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2 + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [ + [30.29459953, 51.69630051], + [65.53179932, 51.50139999], + [48.02519989, 71.73660278], + [33.54930115, 92.3655014], + [62.72990036, 92.20410156] +] + +DEFAULT_CROP_SIZE = (96, 112) + + +class FaceWarpException(Exception): + def __str__(self): + return 'In File {}:{}'.format( + __file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, + inner_padding_factor=0.0, + outer_padding=(0, 0), + default_square=False): + """ + Function: + ---------- + get reference 5 key points according to crop settings: + 0. Set default crop_size: + if default_square: + crop_size = (112, 112) + else: + crop_size = (96, 112) + 1. Pad the crop_size by inner_padding_factor in each side; + 2. Resize crop_size into (output_size - outer_padding*2), + pad into output_size with outer_padding; + 3. Output reference_5point; + Parameters: + ---------- + @output_size: (w, h) or None + size of aligned face image + @inner_padding_factor: (w_factor, h_factor) + padding factor for inner (w, h) + @outer_padding: (w_pad, h_pad) + each row is a pair of coordinates (x, y) + @default_square: True or False + if True: + default crop_size = (112, 112) + else: + default crop_size = (96, 112); + !!! make sure, if output_size is not None: + (output_size - outer_padding) + = some_scale * (default crop_size * (1.0 + inner_padding_factor)) + Returns: + ---------- + @reference_5point: 5x2 np.array + each row is a pair of transformed coordinates (x, y) + """ + # print('\n===> get_reference_facial_points():') + + # print('---> Params:') + # print(' output_size: ', output_size) + # print(' inner_padding_factor: ', inner_padding_factor) + # print(' outer_padding:', outer_padding) + # print(' default_square: ', default_square) + + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + # print('---> default:') + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + if (output_size and + output_size[0] == tmp_crop_size[0] and + output_size[1] == tmp_crop_size[1]): + # print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) + return tmp_5pts + + if (inner_padding_factor == 0 and + outer_padding == (0, 0)): + if output_size is None: + # print('No paddings to do: return default reference points') + return tmp_5pts + else: + raise FaceWarpException( + 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) + and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + # print(' deduced from paddings, output_size = ', output_size) + + if not (outer_padding[0] < output_size[0] + and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0]' + 'and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + # print('---> STEP1: pad the inner region according inner_padding_factor') + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + # 2) resize the padded inner region + # print('---> STEP2: resize the padded inner region') + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + # print(' crop_size = ', tmp_crop_size) + # print(' size_bf_outer_pad = ', size_bf_outer_pad) + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException('Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + # print(' resize scale_factor = ', scale_factor) + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + # print('---> STEP3: add outer_padding to make output_size') + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + # print('===> end get_reference_facial_points\n') + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + """ + Function: + ---------- + get affine transform matrix 'tfm' from src_pts to dst_pts + Parameters: + ---------- + @src_pts: Kx2 np.array + source points matrix, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points matrix, each row is a pair of coordinates (x, y) + Returns: + ---------- + @tfm: 2x3 np.array + transform matrix from src_pts to dst_pts + """ + + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + # #print(('src_pts_:\n' + str(src_pts_)) + # #print(('dst_pts_:\n' + str(dst_pts_)) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + # #print(('np.linalg.lstsq return A: \n' + str(A)) + # #print(('np.linalg.lstsq return res: \n' + str(res)) + # #print(('np.linalg.lstsq return rank: \n' + str(rank)) + # #print(('np.linalg.lstsq return s: \n' + str(s)) + + if rank == 3: + tfm = np.float32([ + [A[0, 0], A[1, 0], A[2, 0]], + [A[0, 1], A[1, 1], A[2, 1]] + ]) + elif rank == 2: + tfm = np.float32([ + [A[0, 0], A[1, 0], 0], + [A[0, 1], A[1, 1], 0] + ]) + + return tfm + + +def warp_and_crop_face(src_img, + facial_pts, + reference_pts=None, + crop_size=(96, 112), + align_type='smilarity'): + """ + Function: + ---------- + apply affine transform 'trans' to uv + Parameters: + ---------- + @src_img: 3x3 np.array + input image + @facial_pts: could be + 1)a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + @reference_pts: could be + 1) a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + or + 3) None + if None, use default reference facial points + @crop_size: (w, h) + output face image size + @align_type: transform type, could be one of + 1) 'similarity': use similarity transform + 2) 'cv2_affine': use the first 3 points to do affine transform, + by calling cv2.getAffineTransform() + 3) 'affine': use all points to do affine transform + Returns: + ---------- + @face_img: output face image with size (w, h) = @crop_size + """ + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, + inner_padding_factor, + outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException( + 'reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException( + 'facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + # #print('--->src_pts:\n', src_pts + # #print('--->ref_pts\n', ref_pts + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException( + 'facial_pts and reference_pts must have the same shape') + + if align_type is 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + # #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm)) + elif align_type is 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + # #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm)) + else: + tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) + # #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm)) + + # #print('--->Transform matrix: ' + # #print(('type(tfm):' + str(type(tfm))) + # #print(('tfm.dtype:' + str(tfm.dtype)) + # #print( tfm + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) + + return face_img, tfm diff --git a/models/mtcnn/mtcnn_pytorch/src/box_utils.py b/models/mtcnn/mtcnn_pytorch/src/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8081b73639a7d70e4391b3d45417569550ddc6 --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/box_utils.py @@ -0,0 +1,238 @@ +import numpy as np +from PIL import Image + + +def nms(boxes, overlap_threshold=0.5, mode='union'): + """Non-maximum suppression. + + Arguments: + boxes: a float numpy array of shape [n, 5], + where each row is (xmin, ymin, xmax, ymax, score). + overlap_threshold: a float number. + mode: 'union' or 'min'. + + Returns: + list with indices of the selected boxes + """ + + # if there are no boxes, return the empty list + if len(boxes) == 0: + return [] + + # list of picked indices + pick = [] + + # grab the coordinates of the bounding boxes + x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] + + area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0) + ids = np.argsort(score) # in increasing order + + while len(ids) > 0: + + # grab index of the largest value + last = len(ids) - 1 + i = ids[last] + pick.append(i) + + # compute intersections + # of the box with the largest score + # with the rest of boxes + + # left top corner of intersection boxes + ix1 = np.maximum(x1[i], x1[ids[:last]]) + iy1 = np.maximum(y1[i], y1[ids[:last]]) + + # right bottom corner of intersection boxes + ix2 = np.minimum(x2[i], x2[ids[:last]]) + iy2 = np.minimum(y2[i], y2[ids[:last]]) + + # width and height of intersection boxes + w = np.maximum(0.0, ix2 - ix1 + 1.0) + h = np.maximum(0.0, iy2 - iy1 + 1.0) + + # intersections' areas + inter = w * h + if mode == 'min': + overlap = inter / np.minimum(area[i], area[ids[:last]]) + elif mode == 'union': + # intersection over union (IoU) + overlap = inter / (area[i] + area[ids[:last]] - inter) + + # delete all boxes where overlap is too big + ids = np.delete( + ids, + np.concatenate([[last], np.where(overlap > overlap_threshold)[0]]) + ) + + return pick + + +def convert_to_square(bboxes): + """Convert bounding boxes to a square form. + + Arguments: + bboxes: a float numpy array of shape [n, 5]. + + Returns: + a float numpy array of shape [n, 5], + squared bounding boxes. + """ + + square_bboxes = np.zeros_like(bboxes) + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + h = y2 - y1 + 1.0 + w = x2 - x1 + 1.0 + max_side = np.maximum(h, w) + square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5 + square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5 + square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 + square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 + return square_bboxes + + +def calibrate_box(bboxes, offsets): + """Transform bounding boxes to be more like true bounding boxes. + 'offsets' is one of the outputs of the nets. + + Arguments: + bboxes: a float numpy array of shape [n, 5]. + offsets: a float numpy array of shape [n, 4]. + + Returns: + a float numpy array of shape [n, 5]. + """ + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + w = x2 - x1 + 1.0 + h = y2 - y1 + 1.0 + w = np.expand_dims(w, 1) + h = np.expand_dims(h, 1) + + # this is what happening here: + # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] + # x1_true = x1 + tx1*w + # y1_true = y1 + ty1*h + # x2_true = x2 + tx2*w + # y2_true = y2 + ty2*h + # below is just more compact form of this + + # are offsets always such that + # x1 < x2 and y1 < y2 ? + + translation = np.hstack([w, h, w, h]) * offsets + bboxes[:, 0:4] = bboxes[:, 0:4] + translation + return bboxes + + +def get_image_boxes(bounding_boxes, img, size=24): + """Cut out boxes from the image. + + Arguments: + bounding_boxes: a float numpy array of shape [n, 5]. + img: an instance of PIL.Image. + size: an integer, size of cutouts. + + Returns: + a float numpy array of shape [n, 3, size, size]. + """ + + num_boxes = len(bounding_boxes) + width, height = img.size + + [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height) + img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') + + for i in range(num_boxes): + img_box = np.zeros((h[i], w[i], 3), 'uint8') + + img_array = np.asarray(img, 'uint8') + img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \ + img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] + + # resize + img_box = Image.fromarray(img_box) + img_box = img_box.resize((size, size), Image.BILINEAR) + img_box = np.asarray(img_box, 'float32') + + img_boxes[i, :, :, :] = _preprocess(img_box) + + return img_boxes + + +def correct_bboxes(bboxes, width, height): + """Crop boxes that are too big and get coordinates + with respect to cutouts. + + Arguments: + bboxes: a float numpy array of shape [n, 5], + where each row is (xmin, ymin, xmax, ymax, score). + width: a float number. + height: a float number. + + Returns: + dy, dx, edy, edx: a int numpy arrays of shape [n], + coordinates of the boxes with respect to the cutouts. + y, x, ey, ex: a int numpy arrays of shape [n], + corrected ymin, xmin, ymax, xmax. + h, w: a int numpy arrays of shape [n], + just heights and widths of boxes. + + in the following order: + [dy, edy, dx, edx, y, ey, x, ex, w, h]. + """ + + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 + num_boxes = bboxes.shape[0] + + # 'e' stands for end + # (x, y) -> (ex, ey) + x, y, ex, ey = x1, y1, x2, y2 + + # we need to cut out a box from the image. + # (x, y, ex, ey) are corrected coordinates of the box + # in the image. + # (dx, dy, edx, edy) are coordinates of the box in the cutout + # from the image. + dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,)) + edx, edy = w.copy() - 1.0, h.copy() - 1.0 + + # if box's bottom right corner is too far right + ind = np.where(ex > width - 1.0)[0] + edx[ind] = w[ind] + width - 2.0 - ex[ind] + ex[ind] = width - 1.0 + + # if box's bottom right corner is too low + ind = np.where(ey > height - 1.0)[0] + edy[ind] = h[ind] + height - 2.0 - ey[ind] + ey[ind] = height - 1.0 + + # if box's top left corner is too far left + ind = np.where(x < 0.0)[0] + dx[ind] = 0.0 - x[ind] + x[ind] = 0.0 + + # if box's top left corner is too high + ind = np.where(y < 0.0)[0] + dy[ind] = 0.0 - y[ind] + y[ind] = 0.0 + + return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] + return_list = [i.astype('int32') for i in return_list] + + return return_list + + +def _preprocess(img): + """Preprocessing step before feeding the network. + + Arguments: + img: a float numpy array of shape [h, w, c]. + + Returns: + a float numpy array of shape [1, c, h, w]. + """ + img = img.transpose((2, 0, 1)) + img = np.expand_dims(img, 0) + img = (img - 127.5) * 0.0078125 + return img diff --git a/models/mtcnn/mtcnn_pytorch/src/detector.py b/models/mtcnn/mtcnn_pytorch/src/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..b162cff3194cc0114abd1a840e5dc772a55edd25 --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/detector.py @@ -0,0 +1,126 @@ +import numpy as np +import torch +from torch.autograd import Variable +from .get_nets import PNet, RNet, ONet +from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square +from .first_stage import run_first_stage + + +def detect_faces(image, min_face_size=20.0, + thresholds=[0.6, 0.7, 0.8], + nms_thresholds=[0.7, 0.7, 0.7]): + """ + Arguments: + image: an instance of PIL.Image. + min_face_size: a float number. + thresholds: a list of length 3. + nms_thresholds: a list of length 3. + + Returns: + two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], + bounding boxes and facial landmarks. + """ + + # LOAD MODELS + pnet = PNet() + rnet = RNet() + onet = ONet() + onet.eval() + + # BUILD AN IMAGE PYRAMID + width, height = image.size + min_length = min(height, width) + + min_detection_size = 12 + factor = 0.707 # sqrt(0.5) + + # scales for scaling the image + scales = [] + + # scales the image so that + # minimum size that we can detect equals to + # minimum face size that we want to detect + m = min_detection_size / min_face_size + min_length *= m + + factor_count = 0 + while min_length > min_detection_size: + scales.append(m * factor ** factor_count) + min_length *= factor + factor_count += 1 + + # STAGE 1 + + # it will be returned + bounding_boxes = [] + + with torch.no_grad(): + # run P-Net on different scales + for s in scales: + boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) + bounding_boxes.append(boxes) + + # collect boxes (and offsets, and scores) from different scales + bounding_boxes = [i for i in bounding_boxes if i is not None] + bounding_boxes = np.vstack(bounding_boxes) + + keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) + bounding_boxes = bounding_boxes[keep] + + # use offsets predicted by pnet to transform bounding boxes + bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) + # shape [n_boxes, 5] + + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 2 + + img_boxes = get_image_boxes(bounding_boxes, image, size=24) + img_boxes = torch.FloatTensor(img_boxes) + + output = rnet(img_boxes) + offsets = output[0].data.numpy() # shape [n_boxes, 4] + probs = output[1].data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 1] > thresholds[1])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) + offsets = offsets[keep] + + keep = nms(bounding_boxes, nms_thresholds[1]) + bounding_boxes = bounding_boxes[keep] + bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 3 + + img_boxes = get_image_boxes(bounding_boxes, image, size=48) + if len(img_boxes) == 0: + return [], [] + img_boxes = torch.FloatTensor(img_boxes) + output = onet(img_boxes) + landmarks = output[0].data.numpy() # shape [n_boxes, 10] + offsets = output[1].data.numpy() # shape [n_boxes, 4] + probs = output[2].data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 1] > thresholds[2])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) + offsets = offsets[keep] + landmarks = landmarks[keep] + + # compute landmark points + width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 + height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 + xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] + landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] + landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] + + bounding_boxes = calibrate_box(bounding_boxes, offsets) + keep = nms(bounding_boxes, nms_thresholds[2], mode='min') + bounding_boxes = bounding_boxes[keep] + landmarks = landmarks[keep] + + return bounding_boxes, landmarks diff --git a/models/mtcnn/mtcnn_pytorch/src/first_stage.py b/models/mtcnn/mtcnn_pytorch/src/first_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..d646f91d5e0348e23bd426701f6afa6000a9b6d1 --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/first_stage.py @@ -0,0 +1,101 @@ +import torch +from torch.autograd import Variable +import math +from PIL import Image +import numpy as np +from .box_utils import nms, _preprocess + +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = 'cuda:0' + + +def run_first_stage(image, net, scale, threshold): + """Run P-Net, generate bounding boxes, and do NMS. + + Arguments: + image: an instance of PIL.Image. + net: an instance of pytorch's nn.Module, P-Net. + scale: a float number, + scale width and height of the image by this number. + threshold: a float number, + threshold on the probability of a face when generating + bounding boxes from predictions of the net. + + Returns: + a float numpy array of shape [n_boxes, 9], + bounding boxes with scores and offsets (4 + 1 + 4). + """ + + # scale the image and convert it to a float array + width, height = image.size + sw, sh = math.ceil(width * scale), math.ceil(height * scale) + img = image.resize((sw, sh), Image.BILINEAR) + img = np.asarray(img, 'float32') + + img = torch.FloatTensor(_preprocess(img)).to(device) + with torch.no_grad(): + output = net(img) + probs = output[1].cpu().data.numpy()[0, 1, :, :] + offsets = output[0].cpu().data.numpy() + # probs: probability of a face at each sliding window + # offsets: transformations to true bounding boxes + + boxes = _generate_bboxes(probs, offsets, scale, threshold) + if len(boxes) == 0: + return None + + keep = nms(boxes[:, 0:5], overlap_threshold=0.5) + return boxes[keep] + + +def _generate_bboxes(probs, offsets, scale, threshold): + """Generate bounding boxes at places + where there is probably a face. + + Arguments: + probs: a float numpy array of shape [n, m]. + offsets: a float numpy array of shape [1, 4, n, m]. + scale: a float number, + width and height of the image were scaled by this number. + threshold: a float number. + + Returns: + a float numpy array of shape [n_boxes, 9] + """ + + # applying P-Net is equivalent, in some sense, to + # moving 12x12 window with stride 2 + stride = 2 + cell_size = 12 + + # indices of boxes where there is probably a face + inds = np.where(probs > threshold) + + if inds[0].size == 0: + return np.array([]) + + # transformations of bounding boxes + tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] + # they are defined as: + # w = x2 - x1 + 1 + # h = y2 - y1 + 1 + # x1_true = x1 + tx1*w + # x2_true = x2 + tx2*w + # y1_true = y1 + ty1*h + # y2_true = y2 + ty2*h + + offsets = np.array([tx1, ty1, tx2, ty2]) + score = probs[inds[0], inds[1]] + + # P-Net is applied to scaled images + # so we need to rescale bounding boxes back + bounding_boxes = np.vstack([ + np.round((stride * inds[1] + 1.0) / scale), + np.round((stride * inds[0] + 1.0) / scale), + np.round((stride * inds[1] + 1.0 + cell_size) / scale), + np.round((stride * inds[0] + 1.0 + cell_size) / scale), + score, offsets + ]) + # why one is added? + + return bounding_boxes.T diff --git a/models/mtcnn/mtcnn_pytorch/src/get_nets.py b/models/mtcnn/mtcnn_pytorch/src/get_nets.py new file mode 100644 index 0000000000000000000000000000000000000000..0b5d3cc64734f0d05b19969fda31dc2bff9b18c6 --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/get_nets.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict +import numpy as np + +from configs.paths_config import model_paths +PNET_PATH = model_paths["mtcnn_pnet"] +ONET_PATH = model_paths["mtcnn_onet"] +RNET_PATH = model_paths["mtcnn_rnet"] + + +class Flatten(nn.Module): + + def __init__(self): + super(Flatten, self).__init__() + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, c, h, w]. + Returns: + a float tensor with shape [batch_size, c*h*w]. + """ + + # without this pretrained model isn't working + x = x.transpose(3, 2).contiguous() + + return x.view(x.size(0), -1) + + +class PNet(nn.Module): + + def __init__(self): + super().__init__() + + # suppose we have input with size HxW, then + # after first layer: H - 2, + # after pool: ceil((H - 2)/2), + # after second conv: ceil((H - 2)/2) - 2, + # after last conv: ceil((H - 2)/2) - 4, + # and the same for W + + self.features = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(3, 10, 3, 1)), + ('prelu1', nn.PReLU(10)), + ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), + + ('conv2', nn.Conv2d(10, 16, 3, 1)), + ('prelu2', nn.PReLU(16)), + + ('conv3', nn.Conv2d(16, 32, 3, 1)), + ('prelu3', nn.PReLU(32)) + ])) + + self.conv4_1 = nn.Conv2d(32, 2, 1, 1) + self.conv4_2 = nn.Conv2d(32, 4, 1, 1) + + weights = np.load(PNET_PATH, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + b: a float tensor with shape [batch_size, 4, h', w']. + a: a float tensor with shape [batch_size, 2, h', w']. + """ + x = self.features(x) + a = self.conv4_1(x) + b = self.conv4_2(x) + a = F.softmax(a, dim=-1) + return b, a + + +class RNet(nn.Module): + + def __init__(self): + super().__init__() + + self.features = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(3, 28, 3, 1)), + ('prelu1', nn.PReLU(28)), + ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), + + ('conv2', nn.Conv2d(28, 48, 3, 1)), + ('prelu2', nn.PReLU(48)), + ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), + + ('conv3', nn.Conv2d(48, 64, 2, 1)), + ('prelu3', nn.PReLU(64)), + + ('flatten', Flatten()), + ('conv4', nn.Linear(576, 128)), + ('prelu4', nn.PReLU(128)) + ])) + + self.conv5_1 = nn.Linear(128, 2) + self.conv5_2 = nn.Linear(128, 4) + + weights = np.load(RNET_PATH, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + b: a float tensor with shape [batch_size, 4]. + a: a float tensor with shape [batch_size, 2]. + """ + x = self.features(x) + a = self.conv5_1(x) + b = self.conv5_2(x) + a = F.softmax(a, dim=-1) + return b, a + + +class ONet(nn.Module): + + def __init__(self): + super().__init__() + + self.features = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(3, 32, 3, 1)), + ('prelu1', nn.PReLU(32)), + ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), + + ('conv2', nn.Conv2d(32, 64, 3, 1)), + ('prelu2', nn.PReLU(64)), + ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), + + ('conv3', nn.Conv2d(64, 64, 3, 1)), + ('prelu3', nn.PReLU(64)), + ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), + + ('conv4', nn.Conv2d(64, 128, 2, 1)), + ('prelu4', nn.PReLU(128)), + + ('flatten', Flatten()), + ('conv5', nn.Linear(1152, 256)), + ('drop5', nn.Dropout(0.25)), + ('prelu5', nn.PReLU(256)), + ])) + + self.conv6_1 = nn.Linear(256, 2) + self.conv6_2 = nn.Linear(256, 4) + self.conv6_3 = nn.Linear(256, 10) + + weights = np.load(ONET_PATH, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + c: a float tensor with shape [batch_size, 10]. + b: a float tensor with shape [batch_size, 4]. + a: a float tensor with shape [batch_size, 2]. + """ + x = self.features(x) + a = self.conv6_1(x) + b = self.conv6_2(x) + c = self.conv6_3(x) + a = F.softmax(a, dim=-1) + return c, b, a diff --git a/models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py b/models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py new file mode 100644 index 0000000000000000000000000000000000000000..025b18ec2e64472bd4c0c636f9ae061526bdc8cd --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py @@ -0,0 +1,350 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Jul 11 06:54:28 2017 + +@author: zhaoyafei +""" + +import numpy as np +from numpy.linalg import inv, norm, lstsq +from numpy.linalg import matrix_rank as rank + + +class MatlabCp2tormException(Exception): + def __str__(self): + return 'In File {}:{}'.format( + __file__, super.__str__(self)) + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack(( + uv, np.ones((uv.shape[0], 1)) + )) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + options = {'K': 2} + + K = options['K'] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + # print('--->x, y:\n', x, y + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + # print('--->X.shape: ', X.shape + # print('X:\n', X + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + # print('--->U.shape: ', U.shape + # print('U:\n', U + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want + r = np.squeeze(r) + else: + raise Exception('cp2tform:twoUniquePointsReq') + + # print('--->r:\n', r + + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([ + [sc, -ss, 0], + [ss, sc, 0], + [tx, ty, 1] + ]) + + # print('--->Tinv:\n', Tinv + + T = inv(Tinv) + # print('--->T:\n', T + + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + options = {'K': 2} + + # uv = np.array(uv) + # xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([ + [-1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack(( + uv, np.ones((uv.shape[0], 1)) + )) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack(( + xy, np.ones((xy.shape[0], 1)) + )) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py b/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bab02be31a6ca44486f98d57de4ab4bfa89394b7 --- /dev/null +++ b/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py @@ -0,0 +1,31 @@ +from PIL import ImageDraw + + +def show_bboxes(img, bounding_boxes, facial_landmarks=[]): + """Draw bounding boxes and facial landmarks. + + Arguments: + img: an instance of PIL.Image. + bounding_boxes: a float numpy array of shape [n, 5]. + facial_landmarks: a float numpy array of shape [n, 10]. + + Returns: + an instance of PIL.Image. + """ + + img_copy = img.copy() + draw = ImageDraw.Draw(img_copy) + + for b in bounding_boxes: + draw.rectangle([ + (b[0], b[1]), (b[2], b[3]) + ], outline='white') + + for p in facial_landmarks: + for i in range(5): + draw.ellipse([ + (p[i] - 1.0, p[i + 5] - 1.0), + (p[i] + 1.0, p[i + 5] + 1.0) + ], outline='blue') + + return img_copy diff --git a/models/psp.py b/models/psp.py new file mode 100644 index 0000000000000000000000000000000000000000..642c983202a1b728cd9d469c8b0dae675692af32 --- /dev/null +++ b/models/psp.py @@ -0,0 +1,118 @@ +""" +This file defines the core research contribution +""" +import matplotlib +matplotlib.use('Agg') +import math + +import torch +from torch import nn +from models.encoders import psp_encoders +from models.stylegan2.model import Generator +from configs.paths_config import model_paths + + +def get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt + + +class pSp(nn.Module): + + def __init__(self, opts): + super(pSp, self).__init__() + self.set_opts(opts) + # compute number of style inputs based on the output resolution + self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 + # Define architecture + self.encoder = self.set_encoder() + self.decoder = Generator(self.opts.output_size, 512, 8) + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + # Load weights if needed + self.load_weights() + + def set_encoder(self): + if self.opts.encoder_type == 'GradualStyleEncoder': + encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) + elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': + encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts) + elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': + encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts) + else: + raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) + return encoder + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) + self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) + self.__load_latent_avg(ckpt) + else: + print('Loading encoders weights from irse50!') + encoder_ckpt = torch.load(model_paths['ir_se50']) + # if input to encoder is not an RGB image, do not load the input layer weights + if self.opts.label_nc != 0: + encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k} + self.encoder.load_state_dict(encoder_ckpt, strict=False) + print('Loading decoder weights from pretrained!') + ckpt = torch.load(self.opts.stylegan_weights) + self.decoder.load_state_dict(ckpt['g_ema'], strict=False) + if self.opts.learn_in_w: + self.__load_latent_avg(ckpt, repeat=1) + else: + self.__load_latent_avg(ckpt, repeat=self.opts.n_styles) + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.encoder(x) + # normalize with respect to the center of an average face + if self.opts.start_from_latent_avg: + if self.opts.learn_in_w: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1) + else: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) + + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + + if return_latents: + result_latent = self.decoder([codes],input_is_latent=input_is_latent,randomize_noise=randomize_noise,return_latents=return_latents) + return result_latent + else: + images, result_latent = self.decoder([codes], + input_is_latent=input_is_latent, + randomize_noise=randomize_noise, + return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + return images + + def set_opts(self, opts): + self.opts = opts + + def __load_latent_avg(self, ckpt, repeat=None): + if 'latent_avg' in ckpt: + self.latent_avg = ckpt['latent_avg'].to(self.opts.device) + if repeat is not None: + self.latent_avg = self.latent_avg.repeat(repeat, 1) + else: + self.latent_avg = None diff --git a/models/stylegan2/__init__.py b/models/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/stylegan2/model.py b/models/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bf88e51cc1d74b0aeb9337b381b37a3f5d54044a --- /dev/null +++ b/models/stylegan2/model.py @@ -0,0 +1,674 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + return_features=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + if return_latents: + return latent + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_features: + return image, out + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/models/stylegan2/op/__init__.py b/models/stylegan2/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/models/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2bba1523cacaf56f97433fb2d92f6726b45c5a67 --- /dev/null +++ b/models/stylegan2/op/fused_act.py @@ -0,0 +1,37 @@ +import os + +import torch +from torch import nn +from torch.autograd import Function +from torch.nn import functional as F + + +module_path = os.path.dirname(__file__) + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): + if input.device.type == "cpu": + if bias is not None: + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 + ) + * scale + ) + + else: + return F.leaky_relu(input, negative_slope=0.2) * scale + + diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f820b4c81e03598589b1ea6b95cf9bef9b04f8 --- /dev/null +++ b/models/stylegan2/op/upfirdn2d.py @@ -0,0 +1,60 @@ +import os + +import torch +from torch.autograd import Function +from torch.nn import functional as F + + + +module_path = os.path.dirname(__file__) + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x + + return out.view(-1, channel, out_h, out_w) diff --git a/pretrained/ohayou_face.pkl b/pretrained/ohayou_face.pkl new file mode 100644 index 0000000000000000000000000000000000000000..eab47ac6641035cc1923b57f3d480481fdf476c1 --- /dev/null +++ b/pretrained/ohayou_face.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89abef3962a9ca6b214f1447e7050725b73d41822d7381e1f4d0f96ac8035381 +size 363965331 diff --git a/pretrained/ohayou_face.pt b/pretrained/ohayou_face.pt new file mode 100644 index 0000000000000000000000000000000000000000..d1664584c900218e0917355951361b5ae68cb713 --- /dev/null +++ b/pretrained/ohayou_face.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c63de7970a7af6cc5b5c0cf677eb16095f2aaabd68dab41fcc3851bb5c7464f9 +size 1077486507 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec5b33ac07c32b35734db662bff72e4c9e712ca3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +torch +numpy +torchvision +Pillow +tqdm +imageio +scipy +easydict +opensimplex==0.3 +ninja \ No newline at end of file diff --git a/torch_utils/__init__.py b/torch_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9 --- /dev/null +++ b/torch_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/torch_utils/custom_ops.py b/torch_utils/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e --- /dev/null +++ b/torch_utils/custom_ops.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import glob +import torch +import torch.utils.cpp_extension +import importlib +import hashlib +import shutil +from pathlib import Path + +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Compile and load. + verbose_build = (verbosity == 'full') + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. + baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/torch_utils/misc.py b/torch_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..0f158cd871e1df433b018a7658ca24dbddc4ea7c --- /dev/null +++ b/torch_utils/misc.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). + +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/__init__.py b/torch_utils/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9 --- /dev/null +++ b/torch_utils/ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/torch_utils/ops/bias_act.py b/torch_utils/ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcb409a89ccf6c6f6ecfca5962683df2d280b1f --- /dev/null +++ b/torch_utils/ops/bias_act.py @@ -0,0 +1,212 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import warnings +import numpy as np +import torch +import dnnlib +import traceback + +from .. import custom_ops +from .. import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _inited, _plugin + if not _inited: + _inited = True + sources = ['bias_act.cpp', 'bias_act.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..e95e10d0b1d0315a63a76446fd4c5c293c8bbc6d --- /dev/null +++ b/torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import warnings +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_resample.py b/torch_utils/ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4750744c83354bab78704d4ef51ad1070fcc4a --- /dev/null +++ b/torch_utils/ops/conv2d_resample.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d_gradfix.conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/fma.py b/torch_utils/ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..2eeac58a626c49231e04122b93e321ada954c5d3 --- /dev/null +++ b/torch_utils/ops/fma.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6b3413ea72a734703c34382c023b84523601fd --- /dev/null +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import warnings +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. + +#---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + if not enabled: + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') + return False + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/upfirdn2d.py b/torch_utils/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..ceeac2b9834e33b7c601c28bf27f32aa91c69256 --- /dev/null +++ b/torch_utils/ops/upfirdn2d.py @@ -0,0 +1,384 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import warnings +import numpy as np +import torch +import traceback + +from .. import custom_ops +from .. import misc +from . import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None + +def _init(): + global _inited, _plugin + if not _inited: + sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- diff --git a/torch_utils/persistence.py b/torch_utils/persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..76ba3db98086743cdd285500670fddfc6bb42777 --- /dev/null +++ b/torch_utils/persistence.py @@ -0,0 +1,251 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for pickling Python code alongside other data. + +The pickled code is automatically imported into a separate Python module +during unpickling. This way, any previously exported pickles will remain +usable even if the original code is no longer available, or if the current +version of the code is not consistent with what was originally pickled.""" + +import sys +import pickle +import io +import inspect +import copy +import uuid +import types +import dnnlib + +#---------------------------------------------------------------------------- + +_version = 6 # internal version number +_decorators = set() # {decorator_class, ...} +_import_hooks = [] # [hook_function, ...] +_module_to_src_dict = dict() # {module: src, ...} +_src_to_module_dict = dict() # {src: module, ...} + +#---------------------------------------------------------------------------- + +def persistent_class(orig_class): + r"""Class decorator that extends a given class to save its source code + when pickled. + + Example: + + from torch_utils import persistence + + @persistence.persistent_class + class MyNetwork(torch.nn.Module): + def __init__(self, num_inputs, num_outputs): + super().__init__() + self.fc = MyLayer(num_inputs, num_outputs) + ... + + @persistence.persistent_class + class MyLayer(torch.nn.Module): + ... + + When pickled, any instance of `MyNetwork` and `MyLayer` will save its + source code alongside other internal state (e.g., parameters, buffers, + and submodules). This way, any previously exported pickle will remain + usable even if the class definitions have been modified or are no + longer available. + + The decorator saves the source code of the entire Python module + containing the decorated class. It does *not* save the source code of + any imported modules. Thus, the imported modules must be available + during unpickling, also including `torch_utils.persistence` itself. + + It is ok to call functions defined in the same module from the + decorated class. However, if the decorated class depends on other + classes defined in the same module, they must be decorated as well. + This is illustrated in the above example in the case of `MyLayer`. + + It is also possible to employ the decorator just-in-time before + calling the constructor. For example: + + cls = MyLayer + if want_to_make_it_persistent: + cls = persistence.persistent_class(cls) + layer = cls(num_inputs, num_outputs) + + As an additional feature, the decorator also keeps track of the + arguments that were used to construct each instance of the decorated + class. The arguments can be queried via `obj.init_args` and + `obj.init_kwargs`, and they are automatically pickled alongside other + object state. A typical use case is to first unpickle a previous + instance of a persistent class, and then upgrade it to use the latest + version of the source code: + + with open('old_pickle.pkl', 'rb') as f: + old_net = pickle.load(f) + new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) + misc.copy_params_and_buffers(old_net, new_net, require_all=True) + """ + assert isinstance(orig_class, type) + if is_persistent(orig_class): + return orig_class + + assert orig_class.__module__ in sys.modules + orig_module = sys.modules[orig_class.__module__] + orig_module_src = _module_to_src(orig_module) + + class Decorator(orig_class): + _orig_module_src = orig_module_src + _orig_class_name = orig_class.__name__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_args = copy.deepcopy(args) + self._init_kwargs = copy.deepcopy(kwargs) + assert orig_class.__name__ in orig_module.__dict__ + _check_pickleable(self.__reduce__()) + + @property + def init_args(self): + return copy.deepcopy(self._init_args) + + @property + def init_kwargs(self): + return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) + + def __reduce__(self): + fields = list(super().__reduce__()) + fields += [None] * max(3 - len(fields), 0) + if fields[0] is not _reconstruct_persistent_obj: + meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) + fields[0] = _reconstruct_persistent_obj # reconstruct func + fields[1] = (meta,) # reconstruct args + fields[2] = None # state dict + return tuple(fields) + + Decorator.__name__ = orig_class.__name__ + _decorators.add(Decorator) + return Decorator + +#---------------------------------------------------------------------------- + +def is_persistent(obj): + r"""Test whether the given object or class is persistent, i.e., + whether it will save its source code when pickled. + """ + try: + if obj in _decorators: + return True + except TypeError: + pass + return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck + +#---------------------------------------------------------------------------- + +def import_hook(hook): + r"""Register an import hook that is called whenever a persistent object + is being unpickled. A typical use case is to patch the pickled source + code to avoid errors and inconsistencies when the API of some imported + module has changed. + + The hook should have the following signature: + + hook(meta) -> modified meta + + `meta` is an instance of `dnnlib.EasyDict` with the following fields: + + type: Type of the persistent object, e.g. `'class'`. + version: Internal version number of `torch_utils.persistence`. + module_src Original source code of the Python module. + class_name: Class name in the original Python module. + state: Internal state of the object. + + Example: + + @persistence.import_hook + def wreck_my_network(meta): + if meta.class_name == 'MyNetwork': + print('MyNetwork is being imported. I will wreck it!') + meta.module_src = meta.module_src.replace("True", "False") + return meta + """ + assert callable(hook) + _import_hooks.append(hook) + +#---------------------------------------------------------------------------- + +def _reconstruct_persistent_obj(meta): + r"""Hook that is called internally by the `pickle` module to unpickle + a persistent object. + """ + meta = dnnlib.EasyDict(meta) + meta.state = dnnlib.EasyDict(meta.state) + for hook in _import_hooks: + meta = hook(meta) + assert meta is not None + + assert meta.version == _version + module = _src_to_module(meta.module_src) + + assert meta.type == 'class' + orig_class = module.__dict__[meta.class_name] + decorator_class = persistent_class(orig_class) + obj = decorator_class.__new__(decorator_class) + + setstate = getattr(obj, '__setstate__', None) + if callable(setstate): + setstate(meta.state) # pylint: disable=not-callable + else: + obj.__dict__.update(meta.state) + return obj + +#---------------------------------------------------------------------------- + +def _module_to_src(module): + r"""Query the source code of a given Python module. + """ + src = _module_to_src_dict.get(module, None) + if src is None: + src = inspect.getsource(module) + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + return src + +def _src_to_module(src): + r"""Get or create a Python module for the given source code. + """ + module = _src_to_module_dict.get(src, None) + if module is None: + module_name = "_imported_module_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + exec(src, module.__dict__) # pylint: disable=exec-used + return module + +#---------------------------------------------------------------------------- + +def _check_pickleable(obj): + r"""Check that the given object is pickleable, raising an exception if + it is not. This function is expected to be considerably more efficient + than actually pickling the object. + """ + def recurse(obj): + if isinstance(obj, (list, tuple, set)): + return [recurse(x) for x in obj] + if isinstance(obj, dict): + return [[recurse(x), recurse(y)] for x, y in obj.items()] + if isinstance(obj, (str, int, float, bool, bytes, bytearray)): + return None # Python primitive types are pickleable. + if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/torch_utils/training_stats.py b/torch_utils/training_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..26f467f9eaa074ee13de1cf2625cd7da44880847 --- /dev/null +++ b/torch_utils/training_stats.py @@ -0,0 +1,268 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. None = single-process. +_sync_called = False # Has _sync() been called yet? +_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor + +#---------------------------------------------------------------------------- + +def init_multiprocessing(rank, sync_device): + r"""Initializes `torch_utils.training_stats` for collecting statistics + across multiple processes. + + This function must be called after + `torch.distributed.init_process_group()` and before `Collector.update()`. + The call is not necessary if multi-process collection is not needed. + + Args: + rank: Rank of the current process. + sync_device: PyTorch device to use for inter-process + communication, or None to disable multi-process + collection. Typically `torch.device('cuda', rank)`. + """ + global _rank, _sync_device + assert not _sync_called + _rank = rank + _sync_device = sync_device + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def report(name, value): + r"""Broadcasts the given set of scalars to all interested instances of + `Collector`, across device and process boundaries. + + This function is expected to be extremely cheap and can be safely + called from anywhere in the training loop, loss function, or inside a + `torch.nn.Module`. + + Warning: The current implementation expects the set of unique names to + be consistent across processes. Please make sure that `report()` is + called at least once for each unique name by each process, and in the + same order. If a given process has no scalars to broadcast, it can do + `report(name, [])` (empty list). + + Args: + name: Arbitrary string specifying the name of the statistic. + Averages are accumulated separately for each unique name. + value: Arbitrary set of scalars. Can be a list, tuple, + NumPy array, PyTorch tensor, or Python scalar. + + Returns: + The same `value` that was passed in. + """ + if name not in _counters: + _counters[name] = dict() + + elems = torch.as_tensor(value) + if elems.numel() == 0: + return value + + elems = elems.detach().flatten().to(_reduce_dtype) + moments = torch.stack([ + torch.ones_like(elems).sum(), + elems.sum(), + elems.square().sum(), + ]) + assert moments.ndim == 1 and moments.shape[0] == _num_moments + moments = moments.to(_counter_dtype) + + device = moments.device + if device not in _counters[name]: + _counters[name][device] = torch.zeros_like(moments) + _counters[name][device].add_(moments) + return value + +#---------------------------------------------------------------------------- + +def report0(name, value): + r"""Broadcasts the given set of scalars by the first process (`rank = 0`), + but ignores any scalars provided by the other processes. + See `report()` for further details. + """ + report(name, value if _rank == 0 else []) + return value + +#---------------------------------------------------------------------------- + +class Collector: + r"""Collects the scalars broadcasted by `report()` and `report0()` and + computes their long-term averages (mean and standard deviation) over + user-defined periods of time. + + The averages are first collected into internal counters that are not + directly visible to the user. They are then copied to the user-visible + state as a result of calling `update()` and can then be queried using + `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the + internal counters for the next round, so that the user-visible state + effectively reflects averages collected between the last two calls to + `update()`. + + Args: + regex: Regular expression defining which statistics to + collect. The default is to collect everything. + keep_previous: Whether to retain the previous averages if no + scalars were collected on a given round + (default: True). + """ + def __init__(self, regex='.*', keep_previous=True): + self._regex = re.compile(regex) + self._keep_previous = keep_previous + self._cumulative = dict() + self._moments = dict() + self.update() + self._moments.clear() + + def names(self): + r"""Returns the names of all statistics broadcasted so far that + match the regular expression specified at construction time. + """ + return [name for name in _counters if self._regex.fullmatch(name)] + + def update(self): + r"""Copies current values of the internal counters to the + user-visible state and resets them for the next round. + + If `keep_previous=True` was specified at construction time, the + operation is skipped for statistics that have received no scalars + since the last update, retaining their previous averages. + + This method performs a number of GPU-to-CPU transfers and one + `torch.distributed.all_reduce()`. It is intended to be called + periodically in the main training loop, typically once every + N training steps. + """ + if not self._keep_previous: + self._moments.clear() + for name, cumulative in _sync(self.names()): + if name not in self._cumulative: + self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + delta = cumulative - self._cumulative[name] + self._cumulative[name].copy_(cumulative) + if float(delta[0]) != 0: + self._moments[name] = delta + + def _get_delta(self, name): + r"""Returns the raw moments that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + assert self._regex.fullmatch(name) + if name not in self._moments: + self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + return self._moments[name] + + def num(self, name): + r"""Returns the number of scalars that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + delta = self._get_delta(name) + return int(delta[0]) + + def mean(self, name): + r"""Returns the mean of the scalars that were accumulated for the + given statistic between the last two calls to `update()`, or NaN if + no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0: + return float('nan') + return float(delta[1] / delta[0]) + + def std(self, name): + r"""Returns the standard deviation of the scalars that were + accumulated for the given statistic between the last two calls to + `update()`, or NaN if no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): + return float('nan') + if int(delta[0]) == 1: + return float(0) + mean = float(delta[1] / delta[0]) + raw_var = float(delta[2] / delta[0]) + return np.sqrt(max(raw_var - np.square(mean), 0)) + + def as_dict(self): + r"""Returns the averages accumulated between the last two calls to + `update()` as an `dnnlib.EasyDict`. The contents are as follows: + + dnnlib.EasyDict( + NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), + ... + ) + """ + stats = dnnlib.EasyDict() + for name in self.names(): + stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) + return stats + + def __getitem__(self, name): + r"""Convenience getter. + `collector[name]` is a synonym for `collector.mean(name)`. + """ + return self.mean(name) + +#---------------------------------------------------------------------------- + +def _sync(names): + r"""Synchronize the global cumulative counters across devices and + processes. Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. + deltas = [] + device = _sync_device if _sync_device is not None else torch.device('cpu') + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. + return [(name, _cumulative[name]) for name in names] + +#---------------------------------------------------------------------------- diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/augment.py b/training/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..3efbf1270a94f08413075c986deeb1570a80f543 --- /dev/null +++ b/training/augment.py @@ -0,0 +1,431 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +import scipy.signal +import torch +from torch_utils import persistence +from torch_utils import misc +from torch_utils.ops import upfirdn2d +from torch_utils.ops import grid_sample_gradfix +from torch_utils.ops import conv2d_gradfix + +#---------------------------------------------------------------------------- +# Coefficients of various wavelet decomposition low-pass filters. + +wavelets = { + 'haar': [0.7071067811865476, 0.7071067811865476], + 'db1': [0.7071067811865476, 0.7071067811865476], + 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], + 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], + 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], + 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], + 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], + 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], + 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], + 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], + 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], + 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], +} + +#---------------------------------------------------------------------------- +# Helpers for constructing transformation matrices. + +def matrix(*rows, device=None): + assert all(len(row) == len(rows[0]) for row in rows) + elems = [x for row in rows for x in row] + ref = [x for x in elems if isinstance(x, torch.Tensor)] + if len(ref) == 0: + return misc.constant(np.asarray(rows), device=device) + assert device is None or device == ref[0].device + elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] + return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) + +def translate2d(tx, ty, **kwargs): + return matrix( + [1, 0, tx], + [0, 1, ty], + [0, 0, 1], + **kwargs) + +def translate3d(tx, ty, tz, **kwargs): + return matrix( + [1, 0, 0, tx], + [0, 1, 0, ty], + [0, 0, 1, tz], + [0, 0, 0, 1], + **kwargs) + +def scale2d(sx, sy, **kwargs): + return matrix( + [sx, 0, 0], + [0, sy, 0], + [0, 0, 1], + **kwargs) + +def scale3d(sx, sy, sz, **kwargs): + return matrix( + [sx, 0, 0, 0], + [0, sy, 0, 0], + [0, 0, sz, 0], + [0, 0, 0, 1], + **kwargs) + +def rotate2d(theta, **kwargs): + return matrix( + [torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + [0, 0, 1], + **kwargs) + +def rotate3d(v, theta, **kwargs): + vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] + s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c + return matrix( + [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], + [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], + [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], + [0, 0, 0, 1], + **kwargs) + +def translate2d_inv(tx, ty, **kwargs): + return translate2d(-tx, -ty, **kwargs) + +def scale2d_inv(sx, sy, **kwargs): + return scale2d(1 / sx, 1 / sy, **kwargs) + +def rotate2d_inv(theta, **kwargs): + return rotate2d(-theta, **kwargs) + +#---------------------------------------------------------------------------- +# Versatile image augmentation pipeline from the paper +# "Training Generative Adversarial Networks with Limited Data". +# +# All augmentations are disabled by default; individual augmentations can +# be enabled by setting their probability multipliers to 1. + +@persistence.persistent_class +class AugmentPipe(torch.nn.Module): + def __init__(self, + xflip=0, rotate90=0, xint=0, xint_max=0.125, + scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, + brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, + imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, + noise=0, cutout=0, noise_std=0.1, cutout_size=0.5, + ): + super().__init__() + self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. + + # Pixel blitting. + self.xflip = float(xflip) # Probability multiplier for x-flip. + self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. + self.xint = float(xint) # Probability multiplier for integer translation. + self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. + + # General geometric transformations. + self.scale = float(scale) # Probability multiplier for isotropic scaling. + self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. + self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. + self.xfrac = float(xfrac) # Probability multiplier for fractional translation. + self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. + self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. + self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. + self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions. + + # Color transformations. + self.brightness = float(brightness) # Probability multiplier for brightness. + self.contrast = float(contrast) # Probability multiplier for contrast. + self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. + self.hue = float(hue) # Probability multiplier for hue rotation. + self.saturation = float(saturation) # Probability multiplier for saturation. + self.brightness_std = float(brightness_std) # Standard deviation of brightness. + self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. + self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. + self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. + + # Image-space filtering. + self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. + self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. + self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. + + # Image-space corruptions. + self.noise = float(noise) # Probability multiplier for additive RGB noise. + self.cutout = float(cutout) # Probability multiplier for cutout. + self.noise_std = float(noise_std) # Standard deviation of additive RGB noise. + self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions. + + # Setup orthogonal lowpass filter for geometric augmentations. + self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) + + # Construct filter bank for image-space filtering. + Hz_lo = np.asarray(wavelets['sym2']) # H(z) + Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) + Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 + Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 + Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) + for i in range(1, Hz_fbank.shape[0]): + Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] + Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) + Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 + self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) + + def forward(self, images, debug_percentile=None): + assert isinstance(images, torch.Tensor) and images.ndim == 4 + batch_size, num_channels, height, width = images.shape + device = images.device + if debug_percentile is not None: + debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) + + # ------------------------------------- + # Select parameters for pixel blitting. + # ------------------------------------- + + # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in + I_3 = torch.eye(3, device=device) + G_inv = I_3 + + # Apply x-flip with probability (xflip * strength). + if self.xflip > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 2) + i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1) + + # Apply 90 degree rotations with probability (rotate90 * strength). + if self.rotate90 > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 4) + i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 4)) + G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i) + + # Apply integer translation with probability (xint * strength). + if self.xint > 0: + t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) + G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) + + # -------------------------------------------------------- + # Select parameters for general geometric transformations. + # -------------------------------------------------------- + + # Apply isotropic scaling with probability (scale * strength). + if self.scale > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) + s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) + G_inv = G_inv @ scale2d_inv(s, s) + + # Apply pre-rotation with probability p_rot. + p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) + G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. + + # Apply anisotropic scaling with probability (aniso * strength). + if self.aniso > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) + s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) + G_inv = G_inv @ scale2d_inv(s, 1 / s) + + # Apply post-rotation with probability p_rot. + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.zeros_like(theta) + G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. + + # Apply fractional translation with probability (xfrac * strength). + if self.xfrac > 0: + t = torch.randn([batch_size, 2], device=device) * self.xfrac_std + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) + G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) + + # ---------------------------------- + # Execute geometric transformations. + # ---------------------------------- + + # Execute if the transform is not identity. + if G_inv is not I_3: + + # Calculate padding. + cx = (width - 1) / 2 + cy = (height - 1) / 2 + cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] + cp = G_inv @ cp.t() # [batch, xyz, idx] + Hz_pad = self.Hz_geom.shape[0] // 4 + margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] + margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] + margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) + margin = margin.max(misc.constant([0, 0] * 2, device=device)) + margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) + mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) + + # Pad image and adjust origin. + images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') + G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv + + # Upsample. + images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) + G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) + G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) + + # Execute transformation. + shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] + G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) + images = grid_sample_gradfix.grid_sample(images, grid) + + # Downsample and crop. + images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) + + # -------------------------------------------- + # Select parameters for color transformations. + # -------------------------------------------- + + # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out + I_4 = torch.eye(4, device=device) + C = I_4 + + # Apply brightness with probability (brightness * strength). + if self.brightness > 0: + b = torch.randn([batch_size], device=device) * self.brightness_std + b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) + if debug_percentile is not None: + b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) + C = translate3d(b, b, b) @ C + + # Apply contrast with probability (contrast * strength). + if self.contrast > 0: + c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) + c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) + if debug_percentile is not None: + c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) + C = scale3d(c, c, c) @ C + + # Apply luma flip with probability (lumaflip * strength). + v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. + if self.lumaflip > 0: + i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2) + i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection. + + # Apply hue rotation with probability (hue * strength). + if self.hue > 0 and num_channels > 1: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max + theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) + C = rotate3d(v, theta) @ C # Rotate around v. + + # Apply saturation with probability (saturation * strength). + if self.saturation > 0 and num_channels > 1: + s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) + s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) + C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C + + # ------------------------------ + # Execute color transformations. + # ------------------------------ + + # Execute if the transform is not identity. + if C is not I_4: + images = images.reshape([batch_size, num_channels, height * width]) + if num_channels == 3: + images = C[:, :3, :3] @ images + C[:, :3, 3:] + elif num_channels == 1: + C = C[:, :3, :].mean(dim=1, keepdims=True) + images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] + else: + raise ValueError('Image must be RGB (3 channels) or L (1 channel)') + images = images.reshape([batch_size, num_channels, height, width]) + + # ---------------------- + # Image-space filtering. + # ---------------------- + + if self.imgfilter > 0: + num_bands = self.Hz_fbank.shape[0] + assert len(self.imgfilter_bands) == num_bands + expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). + + # Apply amplification for each band with probability (imgfilter * strength * band_strength). + g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). + for i, band_strength in enumerate(self.imgfilter_bands): + t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) + t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) + if debug_percentile is not None: + t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) + t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. + t[:, i] = t_i # Replace i'th element. + t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. + g = g * t # Accumulate into global gain. + + # Construct combined amplification filter. + Hz_prime = g @ self.Hz_fbank # [batch, tap] + Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] + Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] + + # Apply filter. + p = self.Hz_fbank.shape[1] // 2 + images = images.reshape([1, batch_size * num_channels, height, width]) + images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) + images = images.reshape([batch_size, num_channels, height, width]) + + # ------------------------ + # Image-space corruptions. + # ------------------------ + + # Apply additive RGB noise with probability (noise * strength). + if self.noise > 0: + sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std + sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma)) + if debug_percentile is not None: + sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std) + images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma + + # Apply cutout with probability (cutout * strength). + if self.cutout > 0: + size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device) + size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size)) + center = torch.rand([batch_size, 2, 1, 1, 1], device=device) + if debug_percentile is not None: + size = torch.full_like(size, self.cutout_size) + center = torch.full_like(center, debug_percentile) + coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) + coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1]) + mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2) + mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2) + mask = torch.logical_or(mask_x, mask_y).to(torch.float32) + images = images * mask + + return images + +#---------------------------------------------------------------------------- diff --git a/training/coach.py b/training/coach.py new file mode 100644 index 0000000000000000000000000000000000000000..63efdb5146154912e0880dab4a65216520a4e9c6 --- /dev/null +++ b/training/coach.py @@ -0,0 +1,290 @@ +import os +import matplotlib +import matplotlib.pyplot as plt + +matplotlib.use('Agg') + +import torch +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +import torch.nn.functional as F + +from utils import common, train_utils +from criteria import id_loss, w_norm, moco_loss +from configs import data_configs +from datasets.images_dataset import ImagesDataset +from criteria.lpips.lpips import LPIPS +from models.psp import pSp +from training.ranger import Ranger + + +class Coach: + def __init__(self, opts): + self.opts = opts + + self.global_step = 0 + + self.device = 'cuda:0' # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES + self.opts.device = self.device + + if self.opts.use_wandb: + from utils.wandb_utils import WBLogger + self.wb_logger = WBLogger(self.opts) + + # Initialize network + self.net = pSp(self.opts).to(self.device) + + # Estimate latent_avg via dense sampling if latent_avg is not available + if self.net.latent_avg is None: + self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0].detach() + + # Initialize loss + if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0: + raise ValueError('Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!') + + self.mse_loss = nn.MSELoss().to(self.device).eval() + if self.opts.lpips_lambda > 0: + self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval() + if self.opts.id_lambda > 0: + self.id_loss = id_loss.IDLoss().to(self.device).eval() + if self.opts.w_norm_lambda > 0: + self.w_norm_loss = w_norm.WNormLoss(start_from_latent_avg=self.opts.start_from_latent_avg) + if self.opts.moco_lambda > 0: + self.moco_loss = moco_loss.MocoLoss().to(self.device).eval() + + # Initialize optimizer + self.optimizer = self.configure_optimizers() + + # Initialize dataset + self.train_dataset, self.test_dataset = self.configure_datasets() + self.train_dataloader = DataLoader(self.train_dataset, + batch_size=self.opts.batch_size, + shuffle=True, + num_workers=int(self.opts.workers), + drop_last=True) + self.test_dataloader = DataLoader(self.test_dataset, + batch_size=self.opts.test_batch_size, + shuffle=False, + num_workers=int(self.opts.test_workers), + drop_last=True) + + # Initialize logger + log_dir = os.path.join(opts.exp_dir, 'logs') + os.makedirs(log_dir, exist_ok=True) + self.logger = SummaryWriter(log_dir=log_dir) + + # Initialize checkpoint dir + self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.best_val_loss = None + if self.opts.save_interval is None: + self.opts.save_interval = self.opts.max_steps + + def train(self): + self.net.train() + while self.global_step < self.opts.max_steps: + for batch_idx, batch in enumerate(self.train_dataloader): + self.optimizer.zero_grad() + x, y = batch + x, y = x.to(self.device).float(), y.to(self.device).float() + y_hat, latent = self.net.forward(x, return_latents=True) + loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) + loss.backward() + self.optimizer.step() + + # Logging related + if self.global_step % self.opts.image_interval == 0 or (self.global_step < 1000 and self.global_step % 25 == 0): + self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces') + if self.global_step % self.opts.board_interval == 0: + self.print_metrics(loss_dict, prefix='train') + self.log_metrics(loss_dict, prefix='train') + + # Log images of first batch to wandb + if self.opts.use_wandb and batch_idx == 0: + self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="train", step=self.global_step, opts=self.opts) + + # Validation related + val_loss_dict = None + if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps: + val_loss_dict = self.validate() + if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss): + self.best_val_loss = val_loss_dict['loss'] + self.checkpoint_me(val_loss_dict, is_best=True) + + if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: + if val_loss_dict is not None: + self.checkpoint_me(val_loss_dict, is_best=False) + else: + self.checkpoint_me(loss_dict, is_best=False) + + if self.global_step == self.opts.max_steps: + print('OMG, finished training!') + break + + self.global_step += 1 + + def validate(self): + self.net.eval() + agg_loss_dict = [] + for batch_idx, batch in enumerate(self.test_dataloader): + x, y = batch + + with torch.no_grad(): + x, y = x.to(self.device).float(), y.to(self.device).float() + y_hat, latent = self.net.forward(x, return_latents=True) + loss, cur_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) + agg_loss_dict.append(cur_loss_dict) + + # Logging related + self.parse_and_log_images(id_logs, x, y, y_hat, + title='images/test/faces', + subscript='{:04d}'.format(batch_idx)) + + # Log images of first batch to wandb + if self.opts.use_wandb and batch_idx == 0: + self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="test", step=self.global_step, opts=self.opts) + + # For first step just do sanity test on small amount of data + if self.global_step == 0 and batch_idx >= 4: + self.net.train() + return None # Do not log, inaccurate in first batch + + loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict) + self.log_metrics(loss_dict, prefix='test') + self.print_metrics(loss_dict, prefix='test') + + self.net.train() + return loss_dict + + def checkpoint_me(self, loss_dict, is_best): + save_name = 'best_model.pt' if is_best else f'iteration_{self.global_step}.pt' + save_dict = self.__get_save_dict() + checkpoint_path = os.path.join(self.checkpoint_dir, save_name) + torch.save(save_dict, checkpoint_path) + with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: + if is_best: + f.write(f'**Best**: Step - {self.global_step}, Loss - {self.best_val_loss} \n{loss_dict}\n') + if self.opts.use_wandb: + self.wb_logger.log_best_model() + else: + f.write(f'Step - {self.global_step}, \n{loss_dict}\n') + + def configure_optimizers(self): + params = list(self.net.encoder.parameters()) + if self.opts.train_decoder: + params += list(self.net.decoder.parameters()) + if self.opts.optim_name == 'adam': + optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) + else: + optimizer = Ranger(params, lr=self.opts.learning_rate) + return optimizer + + def configure_datasets(self): + if self.opts.dataset_type not in data_configs.DATASETS.keys(): + Exception(f'{self.opts.dataset_type} is not a valid dataset_type') + print(f'Loading dataset for {self.opts.dataset_type}') + dataset_args = data_configs.DATASETS[self.opts.dataset_type] + transforms_dict = dataset_args['transforms'](self.opts).get_transforms() + train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'], + target_root=dataset_args['train_target_root'], + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_gt_train'], + opts=self.opts) + test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'], + target_root=dataset_args['test_target_root'], + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_test'], + opts=self.opts) + if self.opts.use_wandb: + self.wb_logger.log_dataset_wandb(train_dataset, dataset_name="Train") + self.wb_logger.log_dataset_wandb(test_dataset, dataset_name="Test") + print(f"Number of training samples: {len(train_dataset)}") + print(f"Number of test samples: {len(test_dataset)}") + return train_dataset, test_dataset + + def calc_loss(self, x, y, y_hat, latent): + loss_dict = {} + loss = 0.0 + id_logs = None + if self.opts.id_lambda > 0: + loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x) + loss_dict['loss_id'] = float(loss_id) + loss_dict['id_improve'] = float(sim_improvement) + loss = loss_id * self.opts.id_lambda + if self.opts.l2_lambda > 0: + loss_l2 = F.mse_loss(y_hat, y) + loss_dict['loss_l2'] = float(loss_l2) + loss += loss_l2 * self.opts.l2_lambda + if self.opts.lpips_lambda > 0: + loss_lpips = self.lpips_loss(y_hat, y) + loss_dict['loss_lpips'] = float(loss_lpips) + loss += loss_lpips * self.opts.lpips_lambda + if self.opts.lpips_lambda_crop > 0: + loss_lpips_crop = self.lpips_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220]) + loss_dict['loss_lpips_crop'] = float(loss_lpips_crop) + loss += loss_lpips_crop * self.opts.lpips_lambda_crop + if self.opts.l2_lambda_crop > 0: + loss_l2_crop = F.mse_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220]) + loss_dict['loss_l2_crop'] = float(loss_l2_crop) + loss += loss_l2_crop * self.opts.l2_lambda_crop + if self.opts.w_norm_lambda > 0: + loss_w_norm = self.w_norm_loss(latent, self.net.latent_avg) + loss_dict['loss_w_norm'] = float(loss_w_norm) + loss += loss_w_norm * self.opts.w_norm_lambda + if self.opts.moco_lambda > 0: + loss_moco, sim_improvement, id_logs = self.moco_loss(y_hat, y, x) + loss_dict['loss_moco'] = float(loss_moco) + loss_dict['id_improve'] = float(sim_improvement) + loss += loss_moco * self.opts.moco_lambda + + loss_dict['loss'] = float(loss) + return loss, loss_dict, id_logs + + def log_metrics(self, metrics_dict, prefix): + for key, value in metrics_dict.items(): + self.logger.add_scalar(f'{prefix}/{key}', value, self.global_step) + if self.opts.use_wandb: + self.wb_logger.log(prefix, metrics_dict, self.global_step) + + def print_metrics(self, metrics_dict, prefix): + print(f'Metrics for {prefix}, step {self.global_step}') + for key, value in metrics_dict.items(): + print(f'\t{key} = ', value) + + def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2): + im_data = [] + for i in range(display_count): + cur_im_data = { + 'input_face': common.log_input_image(x[i], self.opts), + 'target_face': common.tensor2im(y[i]), + 'output_face': common.tensor2im(y_hat[i]), + } + if id_logs is not None: + for key in id_logs[i]: + cur_im_data[key] = id_logs[i][key] + im_data.append(cur_im_data) + self.log_images(title, im_data=im_data, subscript=subscript) + + def log_images(self, name, im_data, subscript=None, log_latest=False): + fig = common.vis_faces(im_data) + step = self.global_step + if log_latest: + step = 0 + if subscript: + path = os.path.join(self.logger.log_dir, name, f'{subscript}_{step:04d}.jpg') + else: + path = os.path.join(self.logger.log_dir, name, f'{step:04d}.jpg') + os.makedirs(os.path.dirname(path), exist_ok=True) + fig.savefig(path) + plt.close(fig) + + def __get_save_dict(self): + save_dict = { + 'state_dict': self.net.state_dict(), + 'opts': vars(self.opts) + } + # save the latent avg in state_dict for inference if truncation of w was used during training + if self.opts.start_from_latent_avg: + save_dict['latent_avg'] = self.net.latent_avg + return save_dict diff --git a/training/dataset.py b/training/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..18540c3c100004d637ca51740a179e690ce5f352 --- /dev/null +++ b/training/dataset.py @@ -0,0 +1,248 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import numpy as np +import zipfile +import PIL.Image +import json +import torch +import dnnlib + +try: + import pyspng +except ImportError: + pyspng = None + +#---------------------------------------------------------------------------- + +class Dataset(torch.utils.data.Dataset): + def __init__(self, + name, # Name of the dataset. + raw_shape, # Shape of the raw image data (NCHW). + max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. + use_labels = False, # Enable conditioning labels? False = label dimension is zero. + xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. + yflip = False, # Apply mirror augment vertically? + random_seed = 0, # Random seed to use when applying max_size. + ): + self._name = name + self._raw_shape = list(raw_shape) + self._use_labels = use_labels + self._raw_labels = None + self._label_shape = None + + # Apply max_size. + self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) + if (max_size is not None) and (self._raw_idx.size > max_size): + np.random.RandomState(random_seed).shuffle(self._raw_idx) + self._raw_idx = np.sort(self._raw_idx[:max_size]) + + # Apply xflip. + self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) + if xflip: + self._raw_idx = np.tile(self._raw_idx, 2) + self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) + + # Apply yflip. + self._yflip = np.zeros(self._raw_idx.size, dtype=np.uint8) + if yflip: + self._raw_idx = np.tile(self._raw_idx, 2) + self._yflip = np.concatenate([self._yflip, np.ones_like(self._yflip)]) + self._xflip = np.tile(self._xflip, 2) # double the indices for xflip, otherwise we get out of bounds + + def _get_raw_labels(self): + if self._raw_labels is None: + self._raw_labels = self._load_raw_labels() if self._use_labels else None + if self._raw_labels is None: + self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) + assert isinstance(self._raw_labels, np.ndarray) + assert self._raw_labels.shape[0] == self._raw_shape[0] + assert self._raw_labels.dtype in [np.float32, np.int64] + if self._raw_labels.dtype == np.int64: + assert self._raw_labels.ndim == 1 + assert np.all(self._raw_labels >= 0) + return self._raw_labels + + def close(self): # to be overridden by subclass + pass + + def _load_raw_image(self, raw_idx): # to be overridden by subclass + raise NotImplementedError + + def _load_raw_labels(self): # to be overridden by subclass + raise NotImplementedError + + def __getstate__(self): + return dict(self.__dict__, _raw_labels=None) + + def __del__(self): + try: + self.close() + except: + pass + + def __len__(self): + return self._raw_idx.size + + def __getitem__(self, idx): + image = self._load_raw_image(self._raw_idx[idx]) + assert isinstance(image, np.ndarray) + assert list(image.shape) == self.image_shape + assert image.dtype == np.uint8 + if self._xflip[idx]: + assert image.ndim == 3 # CHW + image = image[:, :, ::-1] + if self._yflip[idx]: + assert image.ndim == 3 # CHW + image = image[:, ::-1, :] + return image.copy(), self.get_label(idx) + + def get_label(self, idx): + label = self._get_raw_labels()[self._raw_idx[idx]] + if label.dtype == np.int64: + onehot = np.zeros(self.label_shape, dtype=np.float32) + onehot[label] = 1 + label = onehot + return label.copy() + + def get_details(self, idx): + d = dnnlib.EasyDict() + d.raw_idx = int(self._raw_idx[idx]) + d.xflip = (int(self._xflip[idx]) != 0) + d.yflip = (int(self._yflip[idx]) != 0) + d.raw_label = self._get_raw_labels()[d.raw_idx].copy() + return d + + @property + def name(self): + return self._name + + @property + def image_shape(self): + return list(self._raw_shape[1:]) + + @property + def num_channels(self): + assert len(self.image_shape) == 3 # CHW + return self.image_shape[0] + + @property + def resolution(self): + assert len(self.image_shape) == 3 # CHW + assert self.image_shape[1] == self.image_shape[2] + return self.image_shape[1] + + @property + def label_shape(self): + if self._label_shape is None: + raw_labels = self._get_raw_labels() + if raw_labels.dtype == np.int64: + self._label_shape = [int(np.max(raw_labels)) + 1] + else: + self._label_shape = raw_labels.shape[1:] + return list(self._label_shape) + + @property + def label_dim(self): + assert len(self.label_shape) == 1 + return self.label_shape[0] + + @property + def has_labels(self): + return any(x != 0 for x in self.label_shape) + + @property + def has_onehot_labels(self): + return self._get_raw_labels().dtype == np.int64 + +#---------------------------------------------------------------------------- + +class ImageFolderDataset(Dataset): + def __init__(self, + path, # Path to directory or zip. + resolution = None, # Ensure specific resolution, None = highest available. + **super_kwargs, # Additional arguments for the Dataset base class. + ): + self._path = path + self._zipfile = None + + if os.path.isdir(self._path): + self._type = 'dir' + self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} + elif self._file_ext(self._path) == '.zip': + self._type = 'zip' + self._all_fnames = set(self._get_zipfile().namelist()) + else: + raise IOError('Path must point to a directory or zip') + + PIL.Image.init() + self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) + if len(self._image_fnames) == 0: + raise IOError('No image files found in the specified path') + + name = os.path.splitext(os.path.basename(self._path))[0] + raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) + if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): + raise IOError('Image files do not match the specified resolution') + super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) + + @staticmethod + def _file_ext(fname): + return os.path.splitext(fname)[1].lower() + + def _get_zipfile(self): + assert self._type == 'zip' + if self._zipfile is None: + self._zipfile = zipfile.ZipFile(self._path) + return self._zipfile + + def _open_file(self, fname): + if self._type == 'dir': + return open(os.path.join(self._path, fname), 'rb') + if self._type == 'zip': + return self._get_zipfile().open(fname, 'r') + return None + + def close(self): + try: + if self._zipfile is not None: + self._zipfile.close() + finally: + self._zipfile = None + + def __getstate__(self): + return dict(super().__getstate__(), _zipfile=None) + + def _load_raw_image(self, raw_idx): + fname = self._image_fnames[raw_idx] + with self._open_file(fname) as f: + if pyspng is not None and self._file_ext(fname) == '.png': + image = pyspng.load(f.read()) + else: + image = np.array(PIL.Image.open(f)) + if image.ndim == 2: + image = image[:, :, np.newaxis] # HW => HWC + image = image.transpose(2, 0, 1) # HWC => CHW + return image + + def _load_raw_labels(self): + fname = 'dataset.json' + if fname not in self._all_fnames: + return None + with self._open_file(fname) as f: + labels = json.load(f)['labels'] + if labels is None: + return None + labels = dict(labels) + labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] + labels = np.array(labels) + labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) + return labels + +#---------------------------------------------------------------------------- diff --git a/training/loss.py b/training/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5299e84a619eea15aaedbe05d8753522b358c720 --- /dev/null +++ b/training/loss.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +import torch +from torch_utils import training_stats +from torch_utils import misc +from torch_utils.ops import conv2d_gradfix + +#---------------------------------------------------------------------------- + +class Loss: + def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass + raise NotImplementedError() + +#---------------------------------------------------------------------------- + +class StyleGAN2Loss(Loss): + def __init__(self, device, G, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2, G_top_k = False, G_top_k_gamma = 0.9, G_top_k_frac = 0.5,): + super().__init__() + self.device = device + self.G = G + self.G_mapping = G_mapping + self.G_synthesis = G_synthesis + self.D = D + self.augment_pipe = augment_pipe + self.style_mixing_prob = style_mixing_prob + self.r1_gamma = r1_gamma + self.pl_batch_shrink = pl_batch_shrink + self.pl_decay = pl_decay + self.pl_weight = pl_weight + self.pl_mean = torch.zeros([], device=device) + self.G_top_k = G_top_k + self.G_top_k_gamma = G_top_k_gamma + self.G_top_k_frac = G_top_k_frac + + + def run_G(self, z, c, sync): + with misc.ddp_sync(self.G_mapping, sync): + ws = self.G_mapping(z, c) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:] + with misc.ddp_sync(self.G_synthesis, sync): + img = self.G_synthesis(ws) + return img, ws + + def run_D(self, img, c, sync): + if self.augment_pipe is not None: + img = self.augment_pipe(img) + with misc.ddp_sync(self.D, sync): + logits = self.D(img, c) + return logits + + def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): + assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] + do_Gmain = (phase in ['Gmain', 'Gboth']) + do_Dmain = (phase in ['Dmain', 'Dboth']) + do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0) + do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0) + + # Gmain: Maximize logits for generated images. + if do_Gmain: + with torch.autograd.profiler.record_function('Gmain_forward'): + minibatch_size = gen_z.shape[0] + gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl. + gen_logits = self.run_D(gen_img, gen_c, sync=False) + training_stats.report('Loss/scores/fake', gen_logits) + training_stats.report('Loss/signs/fake', gen_logits.sign()) + + # top-k function based on: https://github.com/dvschultz/stylegan2-ada/blob/main/training/loss.py#L102 + if self.G_top_k: + D_fake_scores = gen_logits + k_frac = np.maximum(self.G_top_k_gamma ** self.G.epochs, self.G_top_k_frac) + k = int(np.ceil(minibatch_size * k_frac)) + lowest_k_scores, _ = torch.topk(-torch.squeeze(D_fake_scores), k=k) # want smallest probabilities not largest + gen_logits = torch.unsqueeze(-lowest_k_scores, axis=1) + + loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) + training_stats.report('Loss/G/loss', loss_Gmain) + with torch.autograd.profiler.record_function('Gmain_backward'): + loss_Gmain.mean().mul(gain).backward() + + # Gpl: Apply path length regularization. + if do_Gpl: + with torch.autograd.profiler.record_function('Gpl_forward'): + batch_size = gen_z.shape[0] // self.pl_batch_shrink + gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync) + pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) + with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(): + pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] + pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() + pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) + self.pl_mean.copy_(pl_mean.detach()) + pl_penalty = (pl_lengths - pl_mean).square() + training_stats.report('Loss/pl_penalty', pl_penalty) + loss_Gpl = pl_penalty * self.pl_weight + training_stats.report('Loss/G/reg', loss_Gpl) + with torch.autograd.profiler.record_function('Gpl_backward'): + (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward() + + # Dmain: Minimize logits for generated images. + loss_Dgen = 0 + if do_Dmain: + with torch.autograd.profiler.record_function('Dgen_forward'): + gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False) + gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal. + training_stats.report('Loss/scores/fake', gen_logits) + training_stats.report('Loss/signs/fake', gen_logits.sign()) + loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) + with torch.autograd.profiler.record_function('Dgen_backward'): + loss_Dgen.mean().mul(gain).backward() + + # Dmain: Maximize logits for real images. + # Dr1: Apply R1 regularization. + if do_Dmain or do_Dr1: + name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1' + with torch.autograd.profiler.record_function(name + '_forward'): + real_img_tmp = real_img.detach().requires_grad_(do_Dr1) + real_logits = self.run_D(real_img_tmp, real_c, sync=sync) + training_stats.report('Loss/scores/real', real_logits) + training_stats.report('Loss/signs/real', real_logits.sign()) + + loss_Dreal = 0 + if do_Dmain: + loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) + training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) + + loss_Dr1 = 0 + if do_Dr1: + with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): + r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] + r1_penalty = r1_grads.square().sum([1,2,3]) + loss_Dr1 = r1_penalty * (self.r1_gamma / 2) + training_stats.report('Loss/r1_penalty', r1_penalty) + training_stats.report('Loss/D/reg', loss_Dr1) + + with torch.autograd.profiler.record_function(name + '_backward'): + (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward() + +#---------------------------------------------------------------------------- diff --git a/training/networks.py b/training/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..694e4dbc9c8da3dbbddeb0d0efd349838b7f9d94 --- /dev/null +++ b/training/networks.py @@ -0,0 +1,735 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +import torch +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise = None, # Optional noise tensor to add to the output activations. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + padding = 0, # Padding with respect to the upsampled image. + resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate = True, # Apply weight demodulation? + flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) + if demodulate and noise is not None: + x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(batch_size) + misc.assert_shape(x, [batch_size, in_channels, None, None]) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Conv2dLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output to +-X, None = disable clamping. + channels_last = False, # Expect the input to have memory_format=channels_last? + trainable = True, # Update the weights of this layer during training? + ): + super().__init__() + self.activation = activation + self.up = up + self.down = down + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = bias_act.activation_funcs[activation].def_gain + + memory_format = torch.channels_last if channels_last else torch.contiguous_format + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer('weight', weight) + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + b = self.bias.to(x.dtype) if self.bias is not None else None + flip_weight = (self.up == 1) # slightly faster + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers = 8, # Number of mapping layers. + embed_features = None, # Label embedding dimensionality, None = same as w_dim. + layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + misc.assert_shape(z, [None, self.z_dim]) + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + misc.assert_shape(c, [None, self.c_dim]) + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if self.w_avg_beta is not None and self.training and not skip_w_avg_update: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size = 3, # Convolution kernel size. + up = 1, # Integer upsampling factor. + use_noise = True, # Enable noise input? + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last = False, # Use channels_last format for the weights? + ): + super().__init__() + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): + assert noise_mode in ['random', 'const', 'none'] + in_resolution = self.resolution // self.up + misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution]) + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength + if self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength + + flip_weight = (self.up == 1) # slightly faster + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class ToRGBLayer(torch.nn.Module): + def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): + super().__init__() + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + + if in_channels == 0: + self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, + resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, + conv_clamp=conv_clamp, channels_last=self.channels_last) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. + if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 0, # Use FP16 for the N highest resolutions. + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.num_ws = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res // 2] if res > 4 else 0 + out_channels = channels_dict[res] + use_fp16 = (res >= fp16_resolution) + is_last = (res == self.img_resolution) + block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) + self.num_ws += block.num_conv + if is_last: + self.num_ws += block.num_torgb + setattr(self, f'b{res}', block) + + def forward(self, ws, **block_kwargs): + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32) + w_idx = 0 + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + x = img = None + for res, cur_ws in zip(self.block_resolutions, block_ws): + block = getattr(self, f'b{res}') + x, img = block(x, img, cur_ws, **block_kwargs) + return img + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + mapping_kwargs = {}, # Arguments for MappingNetwork. + synthesis_kwargs = {}, # Arguments for SynthesisNetwork. + epochs = 0., # Track epoch count for top-k + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + self.epochs = 0. + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) + img = self.synthesis(ws, **synthesis_kwargs) + return img + + def update_epochs(self, epoch): + self.epochs = epoch + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class DiscriminatorBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + tmp_channels, # Number of intermediate channels. + out_channels, # Number of output channels. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + first_layer_idx, # Index of the first layer. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + freeze_layers = 0, # Freeze-D: Number of layers to freeze. + ): + assert in_channels in [0, tmp_channels] + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.img_channels = img_channels + self.first_layer_idx = first_layer_idx + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + + self.num_layers = 0 + def trainable_gen(): + while True: + layer_idx = self.first_layer_idx + self.num_layers + trainable = (layer_idx >= freeze_layers) + self.num_layers += 1 + yield trainable + trainable_iter = trainable_gen() + + if in_channels == 0 or architecture == 'skip': + self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) + + self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) + + self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last) + + if architecture == 'resnet': + self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, force_fp32=False): + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + + # Input. + if x is not None: + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # FromRGB. + if self.in_channels == 0 or self.architecture == 'skip': + misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) + img = img.to(dtype=dtype, memory_format=memory_format) + y = self.fromrgb(img) + x = x + y if x is not None else y + img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None + + # Main layers. + if self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + x = self.conv1(x, gain=np.sqrt(0.5)) + x = y.add_(x) + else: + x = self.conv0(x) + x = self.conv1(x) + + assert x.dtype == dtype + return x, img + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants + G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N + F = self.num_channels + c = C // F + + y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class DiscriminatorEpilogue(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.cmap_dim = cmap_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + if architecture == 'skip': + self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation) + self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None + self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp) + self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) + self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim) + + def forward(self, x, img, cmap, force_fp32=False): + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW] + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + # FromRGB. + x = x.to(dtype=dtype, memory_format=memory_format) + if self.architecture == 'skip': + misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) + img = img.to(dtype=dtype, memory_format=memory_format) + x = x + self.fromrgb(img) + + # Main layers. + if self.mbstd is not None: + x = self.mbstd(x) + x = self.conv(x) + x = self.fc(x.flatten(1)) + x = self.out(x) + + # Conditioning. + if self.cmap_dim > 0: + misc.assert_shape(cmap, [None, self.cmap_dim]) + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + assert x.dtype == dtype + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Discriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 0, # Use FP16 for the N highest resolutions. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs = {}, # Arguments for DiscriminatorBlock. + mapping_kwargs = {}, # Arguments for MappingNetwork. + epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + self.epochs = 0. # top-k setting + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) + + def forward(self, img, c, **block_kwargs): + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + +#---------------------------------------------------------------------------- diff --git a/training/ranger.py b/training/ranger.py new file mode 100644 index 0000000000000000000000000000000000000000..3d63264dda6df0ee40cac143440f0b5f8977a9ad --- /dev/null +++ b/training/ranger.py @@ -0,0 +1,164 @@ +# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. + +# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer +# and/or +# https://github.com/lessw2020/Best-Deep-Learning-Optimizers + +# Ranger has now been used to capture 12 records on the FastAI leaderboard. + +# This version = 20.4.11 + +# Credits: +# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization +# RAdam --> https://github.com/LiyuanLucasLiu/RAdam +# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. +# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 + +# summary of changes: +# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. +# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), +# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. +# changes 8/31/19 - fix references to *self*.N_sma_threshold; +# changed eps to 1e-5 as better default than 1e-8. + +import math +import torch +from torch.optim.optimizer import Optimizer + + +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, # lr + alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options + betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options + use_gc=True, gc_conv_only=False + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + ): + + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. + + # prep defaults and init torch.optim base + defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, + eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # look ahead params + + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # gc on or off + self.use_gc = use_gc + + # level of gradient centralization + self.gc_gradient_threshold = 3 if gc_conv_only else 1 + + def __setstate__(self, state): + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + + if grad.is_sparse: + raise RuntimeError('Ranger optimizer does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if len(state) == 0: # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state['slow_buffer'] = torch.empty_like(p.data) + state['slow_buffer'].copy_(p.data) + + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # GC operation for Conv layers and FC layers + if grad.dim() > self.gc_gradient_threshold: + grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) + + state['step'] += 1 + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + # compute mean moving avg + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + buffered = self.radam_buffer[int(state['step'] % 10)] + + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # apply lr + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... + # we do it at the param level instead of group level + if state['step'] % group['k'] == 0: + slow_p = state['slow_buffer'] # get access to slow param tensor + slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha + p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor + + return loss \ No newline at end of file diff --git a/training/stylegan2_multi.py b/training/stylegan2_multi.py new file mode 100644 index 0000000000000000000000000000000000000000..23b003d9003c47c0095ea00b02a0f6e1c987a789 --- /dev/null +++ b/training/stylegan2_multi.py @@ -0,0 +1,414 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +import torch +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma + +from .networks import FullyConnectedLayer, Conv2dLayer, ToRGBLayer, MappingNetwork + +from util.utilgan import hw_scales, fix_size, multimask + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. +# !!! custom + # latmask, # mask for split-frame latents blending + countHW = [1,1], # frame split count by height,width + splitfine = 0., # frame split edge fineness (float from 0+) + size = None, # custom size + scale_type = None, # scaling way: fit, centr, side, pad, padside + noise = None, # Optional noise tensor to add to the output activations. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + padding = 0, # Padding with respect to the upsampled image. + resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate = True, # Apply weight demodulation? + flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) +# !!! custom size & multi latent blending + if size is not None and up==2: + x = fix_size(x, size, scale_type) + # x = multimask(x, size, latmask, countHW, splitfine) + if demodulate and noise is not None: + x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(batch_size) + misc.assert_shape(x, [batch_size, in_channels, None, None]) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) +# !!! custom size & multi latent blending + if size is not None and up==2: + x = fix_size(x, size, scale_type) + # x = multimask(x, size, latmask, countHW, splitfine) + if noise is not None: + x = x.add_(noise) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. +# !!! custom + countHW = [1,1], # frame split count by height,width + splitfine = 0., # frame split edge fineness (float from 0+) + size = None, # custom size + scale_type = None, # scaling way: fit, centr, side, pad, padside + init_res = [4,4], # Initial (minimal) resolution for progressive training + kernel_size = 3, # Convolution kernel size. + up = 1, # Integer upsampling factor. + use_noise = True, # Enable noise input? + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last = False, # Use channels_last format for the weights? + ): + super().__init__() + self.resolution = resolution + self.countHW = countHW # !!! custom + self.splitfine = splitfine # !!! custom + self.size = size # !!! custom + self.scale_type = scale_type # !!! custom + self.init_res = init_res # !!! custom + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: +# !!! custom + self.register_buffer('noise_const', torch.randn([resolution * init_res[0]//4, resolution * init_res[1]//4])) + # self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + +# !!! custom + # def forward(self, x, latmask, w, noise_mode='random', fused_modconv=True, gain=1): + def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): + assert noise_mode in ['random', 'const', 'none'] + in_resolution = self.resolution // self.up + # misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution]) + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == 'random': +# !!! custom + sz = self.size if self.up==2 and self.size is not None else x.shape[2:] + noise = torch.randn([x.shape[0], 1, *sz], device=x.device) * self.noise_strength + # noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength + if self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength +# !!! custom noise size + noise_size = self.size if self.up==2 and self.size is not None and self.resolution > 4 else x.shape[2:] + noise = fix_size(noise.unsqueeze(0).unsqueeze(0), noise_size, scale_type=self.scale_type)[0][0] + + # print(x.shape, noise.shape, self.size, self.up) + + flip_weight = (self.up == 1) # slightly faster + # x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + # latmask=latmask, countHW=self.countHW, splitfine=self.splitfine, size=self.size, scale_type=self.scale_type, # !!! custom + # padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) + + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + countHW=self.countHW, splitfine=self.splitfine, size=self.size, scale_type=self.scale_type, # !!! custom + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) + + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? +# !!! custom + size = None, # custom size + scale_type = None, # scaling way: fit, centr, side, pad, padside + init_res = [4,4], # Initial (minimal) resolution for progressive training + architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.size = size # !!! custom + self.scale_type = scale_type # !!! custom + self.init_res = init_res # !!! custom + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + + if in_channels == 0: +# !!! custom + self.const = torch.nn.Parameter(torch.randn([out_channels, *init_res])) + # self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, + init_res=init_res, scale_type=scale_type, size=size, # !!! custom + resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, + init_res=init_res, scale_type=scale_type, size=size, # !!! custom + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, + conv_clamp=conv_clamp, channels_last=self.channels_last) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + +# !!! custom + # def forward(self, x, img, ws, latmask, dconst, force_fp32=False, fused_modconv=None, **layer_kwargs): + def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) +# !!! custom const size + if 'side' in self.scale_type and 'symm' in self.scale_type: # looks better + const_size = self.init_res if self.size is None else self.size + x = fix_size(x, const_size, self.scale_type) +# distortion technique from Aydao + # x += dconst + else: + # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: +# !!! custom latmask + # x = self.conv1(x, None, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) +# !!! custom latmask + # x = self.conv0(x, latmask, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + # x = self.conv1(x, None, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: +# !!! custom latmask + # x = self.conv0(x, latmask, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + # x = self.conv1(x, None, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. + if img is not None: +# !!! custom img size + # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + img = fix_size(img, self.size, scale_type=self.scale_type) + + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. +# !!! custom + init_res = [4,4], # Initial (minimal) resolution for progressive training + size = None, # Output size + scale_type = None, # scaling way: fit, centr, side, pad, padside + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 0, # Use FP16 for the N highest resolutions. + verbose = False, # + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.res_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.fmap_base = channel_base + self.block_resolutions = [2 ** i for i in range(2, self.res_log2 + 1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + fp16_resolution = max(2 ** (self.res_log2 + 1 - num_fp16_res), 8) + + # calculate intermediate layers sizes for arbitrary output resolution + custom_res = (img_resolution * init_res[0] // 4, img_resolution * init_res[1] // 4) + if size is None: size = custom_res + if init_res != [4,4] and verbose: + print(' .. init res', init_res, size) + keep_first_layers = 2 if scale_type == 'fit' else None + hws = hw_scales(size, custom_res, self.res_log2 - 2, keep_first_layers, verbose) + if verbose: print(hws, '..', custom_res, self.res_log2-1) + + self.num_ws = 0 + for i, res in enumerate(self.block_resolutions): + in_channels = channels_dict[res // 2] if res > 4 else 0 + out_channels = channels_dict[res] + use_fp16 = (res >= fp16_resolution) + is_last = (res == self.img_resolution) + block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, + init_res=init_res, scale_type=scale_type, size=hws[i], # !!! custom + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) + self.num_ws += block.num_conv + if is_last: + self.num_ws += block.num_torgb + setattr(self, f'b{res}', block) + + # def forward(self, ws, latmask, dconst, **block_kwargs): + def forward(self, ws, **block_kwargs): + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32) + w_idx = 0 + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + x = img = None + for res, cur_ws in zip(self.block_resolutions, block_ws): + block = getattr(self, f'b{res}') +# !!! custom + # x, img = block(x, img, cur_ws, latmask, dconst, **block_kwargs) + x, img = block(x, img, cur_ws, **block_kwargs) + return img + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. +# !!! custom + init_res = [4,4], # Initial (minimal) resolution for progressive training + mapping_kwargs = {}, # Arguments for MappingNetwork. + synthesis_kwargs = {}, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.init_res = init_res # !!! custom + self.img_channels = img_channels +# !!! custom + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, init_res=init_res, img_channels=img_channels, **synthesis_kwargs) # !!! custom + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) +# !!! custom + self.output_shape = [1, img_channels, img_resolution * init_res[0] // 4, img_resolution * init_res[1] // 4] + +# !!! custom + # def forward(self, z, c, latmask, dconst, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs): + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs): + # def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) + # img = self.synthesis(ws, latmask, dconst, **synthesis_kwargs) # !!! custom + img = self.synthesis(ws, **synthesis_kwargs) # !!! custom + return img diff --git a/training/training_loop.py b/training/training_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..538548057344083d3f14652e6ef843ce001d55ec --- /dev/null +++ b/training/training_loop.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import time +import copy +import json +import pickle +import psutil +import PIL.Image +import numpy as np +import torch +import dnnlib +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +import legacy +from metrics import metric_main + +#---------------------------------------------------------------------------- + +def setup_snapshot_image_grid(training_set, random_seed=0): + rnd = np.random.RandomState(random_seed) + gw = np.clip(7680 // training_set.image_shape[2], 7, 32) + gh = np.clip(4320 // training_set.image_shape[1], 4, 32) + + # No labels => show random subset of training samples. + if not training_set.has_labels: + all_indices = list(range(len(training_set))) + rnd.shuffle(all_indices) + grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)] + + else: + # Group training samples by label. + label_groups = dict() # label => [idx, ...] + for idx in range(len(training_set)): + label = tuple(training_set.get_details(idx).raw_label.flat[::-1]) + if label not in label_groups: + label_groups[label] = [] + label_groups[label].append(idx) + + # Reorder. + label_order = sorted(label_groups.keys()) + for label in label_order: + rnd.shuffle(label_groups[label]) + + # Organize into grid. + grid_indices = [] + for y in range(gh): + label = label_order[y % len(label_order)] + indices = label_groups[label] + grid_indices += [indices[x % len(indices)] for x in range(gw)] + label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))] + + # Load data. + images, labels = zip(*[training_set[i] for i in grid_indices]) + return (gw, gh), np.stack(images), np.stack(labels) + +#---------------------------------------------------------------------------- + +def save_image_grid(img, fname, drange, grid_size): + lo, hi = drange + img = np.asarray(img, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = np.rint(img).clip(0, 255).astype(np.uint8) + + gw, gh = grid_size + _N, C, H, W = img.shape + img = img.reshape(gh, gw, C, H, W) + img = img.transpose(0, 3, 1, 4, 2) + img = img.reshape(gh * H, gw * W, C) + + assert C in [1, 3] + if C == 1: + PIL.Image.fromarray(img[:, :, 0], 'L').save(fname) + if C == 3: + PIL.Image.fromarray(img, 'RGB').save(fname) + +#---------------------------------------------------------------------------- + +def training_loop( + run_dir = '.', # Output directory. + training_set_kwargs = {}, # Options for training set. + data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. + G_kwargs = {}, # Options for generator network. + D_kwargs = {}, # Options for discriminator network. + G_opt_kwargs = {}, # Options for generator optimizer. + D_opt_kwargs = {}, # Options for discriminator optimizer. + augment_kwargs = None, # Options for augmentation pipeline. None = disable. + loss_kwargs = {}, # Options for loss function. + metrics = [], # Metrics to evaluate during training. + random_seed = 0, # Global random seed. + num_gpus = 1, # Number of GPUs participating in the training. + rank = 0, # Rank of the current process in [0, num_gpus[. + batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. + batch_gpu = 4, # Number of samples processed at a time by one GPU. + ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights. + ema_rampup = None, # EMA ramp-up coefficient. + G_reg_interval = 4, # How often to perform regularization for G? None = disable lazy regularization. + D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization. + augment_p = 0, # Initial value of augmentation probability. + ada_target = None, # ADA target value. None = fixed p. + ada_interval = 4, # How often to perform ADA adjustment? + ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. + nimg = 0, # current image count + total_kimg = 25000, # Total length of the training, measured in thousands of real images. + kimg_per_tick = 4, # Progress snapshot interval. + image_snapshot_ticks = 50, # How often to save image snapshots? None = disable. + network_snapshot_ticks = 50, # How often to save network snapshots? None = disable. + resume_pkl = None, # Network pickle to resume training from. + cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? + allow_tf32 = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32? + abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks. + progress_fn = None, # Callback function for updating training progress. Called for all ranks. +): + # Initialize. + start_time = time.time() + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for matmul + torch.backends.cudnn.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for convolutions + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Load training set. + if rank == 0: + print('Loading training set...') + training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset + training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) + training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)) + if rank == 0: + print() + print('Num images: ', len(training_set)) + print('Image shape:', training_set.image_shape) + print('Label shape:', training_set.label_shape) + print() + + # Construct networks. + if rank == 0: + print('Constructing networks...') + common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) + G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + G_ema = copy.deepcopy(G).eval() + + G.update_epochs( float(100 * nimg / (total_kimg * 1000)) ) # 100 total top k "epochs" in total_kimg + print('starting G epochs: ',G.epochs) + + # Resume from existing pickle. + if (resume_pkl is not None) and (rank == 0): + print(f'Resuming from "{resume_pkl}"') + with dnnlib.util.open_url(resume_pkl) as f: + resume_data = legacy.load_network_pkl(f) + for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: + misc.copy_params_and_buffers(resume_data[name], module, require_all=False) + + # Print network summary tables. + if rank == 0: + z = torch.empty([batch_gpu, G.z_dim], device=device) + c = torch.empty([batch_gpu, G.c_dim], device=device) + img = misc.print_module_summary(G, [z, c]) + misc.print_module_summary(D, [img, c]) + + # Setup augmentation. + if rank == 0: + print('Setting up augmentation...') + augment_pipe = None + ada_stats = None + if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): + augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + augment_pipe.p.copy_(torch.as_tensor(augment_p)) + if ada_target is not None: + ada_stats = training_stats.Collector(regex='Loss/signs/real') + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + ddp_modules = dict() + for name, module in [('G', G),('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]: + if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0: + module.requires_grad_(True) + module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False) + module.requires_grad_(False) + if name is not None: + ddp_modules[name] = module + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss + phases = [] + for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]: + if reg_interval is None: + opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer + phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)] + else: # Lazy regularization. + mb_ratio = reg_interval / (reg_interval + 1) + opt_kwargs = dnnlib.EasyDict(opt_kwargs) + opt_kwargs.lr = opt_kwargs.lr * mb_ratio + opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] + opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer + phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)] + phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)] + for phase in phases: + phase.start_event = None + phase.end_event = None + if rank == 0: + phase.start_event = torch.cuda.Event(enable_timing=True) + phase.end_event = torch.cuda.Event(enable_timing=True) + + # Export sample images. + grid_size = None + grid_z = None + grid_c = None + if rank == 0: + print('Exporting sample images...') + grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set) + save_image_grid(images, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size) + grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) + grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) + images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() + save_image_grid(images, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size) + + # Initialize logs. + if rank == 0: + print('Initializing logs...') + stats_collector = training_stats.Collector(regex='.*') + stats_metrics = dict() + stats_jsonl = None + stats_tfevents = None + if rank == 0: + stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') + try: + import torch.utils.tensorboard as tensorboard + stats_tfevents = tensorboard.SummaryWriter(run_dir) + except ImportError as err: + print('Skipping tfevents export:', err) + + # Train. + if rank == 0: + print(f'Training for {total_kimg} kimg...') + print() + cur_nimg = nimg + cur_tick = 0 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - start_time + batch_idx = 0 + if progress_fn is not None: + progress_fn(0, total_kimg) + while True: + + # Fetch training data. + with torch.autograd.profiler.record_function('data_fetch'): + phase_real_img, phase_real_c = next(training_set_iterator) + phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) + phase_real_c = phase_real_c.to(device).split(batch_gpu) + all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device) + all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)] + all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)] + all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device) + all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)] + + # Execute training phases. + for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c): + if batch_idx % phase.interval != 0: + continue + + G.update_epochs( float(100 * nimg / (total_kimg * 1000)) ) # 100 total top k "epochs" in total_kimg + + # Initialize gradient accumulation. + if phase.start_event is not None: + phase.start_event.record(torch.cuda.current_stream(device)) + phase.opt.zero_grad(set_to_none=True) + phase.module.requires_grad_(True) + + # Accumulate gradients over multiple rounds. + for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)): + sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1) + gain = phase.interval + loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain) + + # Update weights. + phase.module.requires_grad_(False) + with torch.autograd.profiler.record_function(phase.name + '_opt'): + for param in phase.module.parameters(): + if param.grad is not None: + misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) + phase.opt.step() + if phase.end_event is not None: + phase.end_event.record(torch.cuda.current_stream(device)) + + # Update G_ema. + with torch.autograd.profiler.record_function('Gema'): + ema_nimg = ema_kimg * 1000 + if ema_rampup is not None: + ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) + ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) + for p_ema, p in zip(G_ema.parameters(), G.parameters()): + p_ema.copy_(p.lerp(p_ema, ema_beta)) + for b_ema, b in zip(G_ema.buffers(), G.buffers()): + b_ema.copy_(b) + + # Update state. + cur_nimg += batch_size + batch_idx += 1 + + # Execute ADA heuristic. + if (ada_stats is not None) and (batch_idx % ada_interval == 0): + ada_stats.update() + adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000) + augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device))) + + # Perform maintenance tasks once per tick. + done = (cur_nimg >= total_kimg * 1000) + if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): + continue + + # Print status line, accumulating the same information in stats_collector. + tick_end_time = time.time() + fields = [] + fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] + fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"] + fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] + fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] + fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] + fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] + fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] + fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] + torch.cuda.reset_peak_memory_stats() + fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"] + training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60)) + training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60)) + if rank == 0: + print(' '.join(fields)) + + # Check for abort. + if (not done) and (abort_fn is not None) and abort_fn(): + done = True + if rank == 0: + print() + print('Aborting...') + + # Save image snapshot. + if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0): + images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() + save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size) + + # Save network snapshot. + snapshot_pkl = None + snapshot_data = None + if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): + snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs)) + for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]: + if module is not None: + if num_gpus > 1: + misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg') + module = copy.deepcopy(module).eval().requires_grad_(False).cpu() + snapshot_data[name] = module + del module # conserve memory + snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') + if rank == 0: + with open(snapshot_pkl, 'wb') as f: + pickle.dump(snapshot_data, f) + + # Evaluate metrics. + if (snapshot_data is not None) and (len(metrics) > 0): + if rank == 0: + print('Evaluating metrics...') + for metric in metrics: + result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], + dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device) + if rank == 0: + metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) + stats_metrics.update(result_dict.results) + del snapshot_data # conserve memory + + # Collect statistics. + for phase in phases: + value = [] + if (phase.start_event is not None) and (phase.end_event is not None): + phase.end_event.synchronize() + value = phase.start_event.elapsed_time(phase.end_event) + training_stats.report0('Timing/' + phase.name, value) + stats_collector.update() + stats_dict = stats_collector.as_dict() + + # Update logs. + timestamp = time.time() + if stats_jsonl is not None: + fields = dict(stats_dict, timestamp=timestamp) + stats_jsonl.write(json.dumps(fields) + '\n') + stats_jsonl.flush() + if stats_tfevents is not None: + global_step = int(cur_nimg / 1e3) + walltime = timestamp - start_time + for name, value in stats_dict.items(): + stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) + for name, value in stats_metrics.items(): + stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) + stats_tfevents.flush() + if progress_fn is not None: + progress_fn(cur_nimg // 1000, total_kimg) + + # Update state. + cur_tick += 1 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - tick_end_time + if done: + break + + # Done. + if rank == 0: + print() + print('Exiting...') + +#---------------------------------------------------------------------------- diff --git a/util/utilgan.py b/util/utilgan.py new file mode 100644 index 0000000000000000000000000000000000000000..141a150eb631a21cd8f8d8c4717994e0fff8f27a --- /dev/null +++ b/util/utilgan.py @@ -0,0 +1,371 @@ +import os +import sys +import time +import math +import numpy as np +from scipy.ndimage import gaussian_filter +from scipy.interpolate import CubicSpline as CubSpline +from scipy.special import comb +import scipy +from imageio import imread + +import torch +import torch.nn.functional as F + +# from perlin import PerlinNoiseFactory as Perlin +# noise = Perlin(1) + +# def latent_noise(t, dim, noise_step=78564.543): + # latent = np.zeros((1, dim)) + # for i in range(dim): + # latent[0][i] = noise(t + i * noise_step) + # return latent + +def load_latents(npy_file): + key_latents = np.load(npy_file) + try: + key_latents = key_latents[key_latents.files[0]] + except: + pass + idx_file = os.path.splitext(npy_file)[0] + '.txt' + if os.path.exists(idx_file): + with open(idx_file) as f: + lat_idx = f.readline() + lat_idx = [int(l.strip()) for l in lat_idx.split(',') if '\n' not in l and len(l.strip())>0] + key_latents = [key_latents[i] for i in lat_idx] + return np.asarray(key_latents) + +# = = = = = = = = = = = = = = = = = = = = = = = = = = = + +def get_z(shape, seed=None, uniform=False): + if seed is None: + seed = np.random.seed(int((time.time()%1) * 9999)) + rnd = np.random.RandomState(seed) + if uniform: + return rnd.uniform(0., 1., shape) + else: + return rnd.randn(*shape) # *x unpacks tuple/list to sequence + +def smoothstep(x, NN=1., xmin=0., xmax=1.): + N = math.ceil(NN) + x = np.clip((x - xmin) / (xmax - xmin), 0, 1) + result = 0 + for n in range(0, N+1): + result += scipy.special.comb(N+n, n) * scipy.special.comb(2*N+1, N-n) * (-x)**n + result *= x**(N+1) + if NN != N: result = (x + result) / 2 + return result + +def lerp(z1, z2, num_steps, smooth=0.): + vectors = [] + xs = [step / (num_steps - 1) for step in range(num_steps)] + if smooth > 0: xs = [smoothstep(x, smooth) for x in xs] + for x in xs: + interpol = z1 + (z2 - z1) * x + vectors.append(interpol) + return np.array(vectors) + +# interpolate on hypersphere +def slerp(z1, z2, num_steps, smooth=0.): + z1_norm = np.linalg.norm(z1) + z2_norm = np.linalg.norm(z2) + z2_normal = z2 * (z1_norm / z2_norm) + vectors = [] + xs = [step / (num_steps - 1) for step in range(num_steps)] + if smooth > 0: xs = [smoothstep(x, smooth) for x in xs] + for x in xs: + interplain = z1 + (z2 - z1) * x + interp = z1 + (z2_normal - z1) * x + interp_norm = np.linalg.norm(interp) + interpol_normal = interplain * (z1_norm / interp_norm) + # interpol_normal = interp * (z1_norm / interp_norm) + vectors.append(interpol_normal) + return np.array(vectors) + +def cublerp(points, steps, fstep): + keys = np.array([i*fstep for i in range(steps)] + [steps*fstep]) + points = np.concatenate((points, np.expand_dims(points[0], 0))) + cspline = CubSpline(keys, points) + return cspline(range(steps*fstep+1)) + +# = = = = = = = = = = = = = = = = = = = = = = = = = = = + +def latent_anima(shape, frames, transit, key_latents=None, smooth=0.5, cubic=False, gauss=False, seed=None, verbose=True): + if key_latents is None: + transit = int(max(1, min(frames//4, transit))) + steps = max(1, int(frames // transit)) + log = ' timeline: %d steps by %d' % (steps, transit) + + getlat = lambda : get_z(shape, seed=seed) + + # make key points + if key_latents is None: + key_latents = np.array([getlat() for i in range(steps)]) + + latents = np.expand_dims(key_latents[0], 0) + + # populate lerp between key points + if transit == 1: + latents = key_latents + else: + if cubic: + latents = cublerp(key_latents, steps, transit) + log += ', cubic' + else: + for i in range(steps): + zA = key_latents[i] + zB = key_latents[(i+1) % steps] + interps_z = slerp(zA, zB, transit, smooth=smooth) + latents = np.concatenate((latents, interps_z)) + latents = np.array(latents) + + if gauss: + lats_post = gaussian_filter(latents, [transit, 0, 0], mode="wrap") + lats_post = (lats_post / np.linalg.norm(lats_post, axis=-1, keepdims=True)) * math.sqrt(np.prod(shape)) + log += ', gauss' + latents = lats_post + + if verbose: print(log) + if latents.shape[0] > frames: # extra frame + latents = latents[1:] + return latents + +# = = = = = = = = = = = = = = = = = = = = = = = = = = = + +def multimask(x, size, latmask=None, countHW=[1,1], delta=0.): + Hx, Wx = countHW + bcount = x.shape[0] + + if max(countHW) > 1: + W = x.shape[3] # width + H = x.shape[2] # height + if Wx > 1: + stripe_mask = [] + for i in range(Wx): + ch_mask = peak_roll(W, Wx, i, delta).unsqueeze(0).unsqueeze(0) # [1,1,w] th + ch_mask = ch_mask.repeat(1,H,1) # [1,h,w] + stripe_mask.append(ch_mask) + maskW = torch.cat(stripe_mask, 0).unsqueeze(1) # [x,1,h,w] + else: maskW = [1] + if Hx > 1: + stripe_mask = [] + for i in range(Hx): + ch_mask = peak_roll(H, Hx, i, delta).unsqueeze(1).unsqueeze(0) # [1,h,1] th + ch_mask = ch_mask.repeat(1,1,W) # [1,h,w] + stripe_mask.append(ch_mask) + maskH = torch.cat(stripe_mask, 0).unsqueeze(1) # [y,1,h,w] + else: maskH = [1] + + mask = [] + for i in range(Wx): + for j in range(Hx): + mask.append(maskW[i] * maskH[j]) + mask = torch.cat(mask, 0).unsqueeze(1) # [xy,1,h,w] + mask = mask.to(x.device) + x = torch.sum(x[:Hx*Wx] * mask, 0, keepdim=True) + + elif latmask is not None: + if len(latmask.shape) < 4: + latmask = latmask.unsqueeze(1) # [b,1,h,w] + lms = latmask.shape + if list(lms[2:]) != list(size) and np.prod(lms) > 1: + latmask = F.interpolate(latmask, size) # , mode='nearest' + latmask = latmask.type(x.dtype) + x = torch.sum(x[:lms[0]] * latmask, 0, keepdim=True) + else: + return x + + x = x.repeat(bcount,1,1,1) + return x # [b,f,h,w] + +def peak_roll(width, count, num, delta): + step = width // count + if width > step*2: + fill_range = torch.zeros([width-step*2]) + full_ax = torch.cat((peak(step, delta), fill_range), 0) + else: + full_ax = peak(step, delta)[:width] + if num == 0: + shift = max(width - (step//2), 0.) # must be positive! + else: + shift = step*num - (step//2) + full_ax = torch.roll(full_ax, shift, 0) + return full_ax # [width,] + +def peak(steps, delta): + x = torch.linspace(0.-delta, 1.+ delta, steps) + x_rev = torch.flip(x,[0]) + x = torch.cat((x, x_rev), 0) + x = torch.clip(x, 0., 1.) + return x # [steps*2,] + +# = = = = = = = = = = = = = = = = = = = = = = = = = = = + +def ups2d(x, factor=2): + assert isinstance(factor, int) and factor >= 1 + if factor == 1: return x + s = x.shape + x = x.reshape(-1, s[1], s[2], 1, s[3], 1) + x = x.repeat(1, 1, 1, factor, 1, factor) + x = x.reshape(-1, s[1], s[2] * factor, s[3] * factor) + return x + +# Tiles an array around two points, allowing for pad lengths greater than the input length +# NB: if symm=True, every second tile is mirrored = messed up in GAN +# adapted from https://discuss.pytorch.org/t/symmetric-padding/19866/3 +def tile_pad(xt, padding, symm=True): + h, w = xt.shape[-2:] + left, right, top, bottom = padding + + def tile(x, minx, maxx, symm=True): + rng = maxx - minx + if symm is True: # triangular reflection + double_rng = 2*rng + mod = np.fmod(x - minx, double_rng) + normed_mod = np.where(mod < 0, mod+double_rng, mod) + out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx + else: # repeating tiles + mod = np.remainder(x - minx, rng) + out = mod + minx + return np.array(out, dtype=x.dtype) + + x_idx = np.arange(-left, w+right) + y_idx = np.arange(-top, h+bottom) + x_pad = tile(x_idx, -0.5, w-0.5, symm) + y_pad = tile(y_idx, -0.5, h-0.5, symm) + xx, yy = np.meshgrid(x_pad, y_pad) + return xt[..., yy, xx] + +def pad_up_to(x, size, type='centr'): + sh = x.shape[2:][::-1] + if list(x.shape[2:]) == list(size): return x + padding = [] + for i, s in enumerate(size[::-1]): + if 'side' in type.lower(): + padding = padding + [0, s-sh[i]] + else: # centr + p0 = (s-sh[i]) // 2 + p1 = s-sh[i] - p0 + padding = padding + [p0,p1] + y = tile_pad(x, padding, symm = 'symm' in type.lower()) + # if 'symm' in type.lower(): + # y = tile_pad(x, padding, symm=True) + # else: + # y = F.pad(x, padding, 'circular') + return y + +# scale_type may include pad, side, symm +def fix_size(x, size, scale_type='centr'): + if not len(x.shape) == 4: + raise Exception(" Wrong data rank, shape:", x.shape) + if x.shape[2:] == size: + return x + if (x.shape[2]*2, x.shape[3]*2) == size: + return ups2d(x) + + if scale_type.lower() == 'fit': + return F.interpolate(x, size, mode='nearest') # , align_corners=True + elif 'pad' in scale_type.lower(): + pass + else: # proportional scale to smaller side, then pad to bigger side + sh0 = x.shape[2:] + upsc = np.min(size) / np.min(sh0) + new_size = [int(sh0[i]*upsc) for i in [0,1]] + x = F.interpolate(x, new_size, mode='nearest') # , align_corners=True + + x = pad_up_to(x, size, scale_type) + return x + +# Make list of odd sizes for upsampling to arbitrary resolution +def hw_scales(size, base, n, keep_first_layers=None, verbose=False): + if isinstance(base, int): base = (base, base) + start_res = [int(b * 2 ** (-n)) for b in base] + + start_res[0] = int(start_res[0] * size[0] // base[0]) + start_res[1] = int(start_res[1] * size[1] // base[1]) + + hw_list = [] + + if base[0] != base[1] and verbose is True: + print(' size', size, 'base', base, 'start_res', start_res, 'n', n) + if keep_first_layers is not None and keep_first_layers > 0: + for i in range(keep_first_layers): + hw_list.append(start_res) + start_res = [x*2 for x in start_res] + n -= 1 + + ch = (size[0] / start_res[0]) ** (1/n) + cw = (size[1] / start_res[1]) ** (1/n) + for i in range(n): + h = math.floor(start_res[0] * ch**i) + w = math.floor(start_res[1] * cw**i) + hw_list.append((h,w)) + + hw_list.append(size) + return hw_list + +def calc_res(shape): + base0 = 2**int(np.log2(shape[0])) + base1 = 2**int(np.log2(shape[1])) + base = min(base0, base1) + min_res = min(shape[0], shape[1]) + + def int_log2(xs, base): + return [x * 2**(2-int(np.log2(base))) % 1 == 0 for x in xs] + if min_res != base or max(*shape) / min(*shape) >= 2: + if np.log2(base) < 10 and all(int_log2(shape, base*2)): + base = base * 2 + + return base # , [shape[0]/base, shape[1]/base] + +def calc_init_res(shape, resolution=None): + if len(shape) == 1: + shape = [shape[0], shape[0], 1] + elif len(shape) == 2: + shape = [*shape, 1] + size = shape[:2] if shape[2] < min(*shape[:2]) else shape[1:] # fewer colors than pixels + if resolution is None: + resolution = calc_res(size) + res_log2 = int(np.log2(resolution)) + init_res = [int(s * 2**(2-res_log2)) for s in size] + return init_res, resolution, res_log2 + +def basename(file): + return os.path.splitext(os.path.basename(file))[0] + +def file_list(path, ext=None, subdir=None): + if subdir is True: + files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn] + else: + files = [os.path.join(path, f) for f in os.listdir(path)] + if ext is not None: + if isinstance(ext, list): + files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ext] + elif isinstance(ext, str): + files = [f for f in files if f.endswith(ext)] + else: + print(' Unknown extension/type for file list!') + return sorted([f for f in files if os.path.isfile(f)]) + +def dir_list(in_dir): + dirs = [os.path.join(in_dir, x) for x in os.listdir(in_dir)] + return sorted([f for f in dirs if os.path.isdir(f)]) + +def img_list(path, subdir=None): + if subdir is True: + files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn] + else: + files = [os.path.join(path, f) for f in os.listdir(path)] + files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ['jpg', 'jpeg', 'png', 'ppm', 'tif']] + return sorted([f for f in files if os.path.isfile(f)]) + +def img_read(path): + img = imread(path) + # 8bit to 256bit + if (img.ndim == 2) or (img.shape[2] == 1): + img = np.dstack((img,img,img)) + # rgba to rgb + if img.shape[2] == 4: + img = img[:,:,:3] + return img + \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/common.py b/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..4813fe311ee40720697e4862c5fbfad811d39237 --- /dev/null +++ b/utils/common.py @@ -0,0 +1,87 @@ +import cv2 +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt + + +# Log images +def log_input_image(x, opts): + if opts.label_nc == 0: + return tensor2im(x) + elif opts.label_nc == 1: + return tensor2sketch(x) + else: + return tensor2map(x) + + +def tensor2im(var): + var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() + var = ((var + 1) / 2) + var[var < 0] = 0 + var[var > 1] = 1 + var = var * 255 + return Image.fromarray(var.astype('uint8')) + + +def tensor2map(var): + mask = np.argmax(var.data.cpu().numpy(), axis=0) + colors = get_colors() + mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3)) + for class_idx in np.unique(mask): + mask_image[mask == class_idx] = colors[class_idx] + mask_image = mask_image.astype('uint8') + return Image.fromarray(mask_image) + + +def tensor2sketch(var): + im = var[0].cpu().detach().numpy() + im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) + im = (im * 255).astype(np.uint8) + return Image.fromarray(im) + + +# Visualization utils +def get_colors(): + # currently support up to 19 classes (for the celebs-hq-mask dataset) + colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], + [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], + [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] + return colors + + +def vis_faces(log_hooks): + display_count = len(log_hooks) + fig = plt.figure(figsize=(8, 4 * display_count)) + gs = fig.add_gridspec(display_count, 3) + for i in range(display_count): + hooks_dict = log_hooks[i] + fig.add_subplot(gs[i, 0]) + if 'diff_input' in hooks_dict: + vis_faces_with_id(hooks_dict, fig, gs, i) + else: + vis_faces_no_id(hooks_dict, fig, gs, i) + plt.tight_layout() + return fig + + +def vis_faces_with_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face']) + plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), + float(hooks_dict['diff_target']))) + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) + + +def vis_faces_no_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face'], cmap="gray") + plt.title('Input') + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target') + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output') diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ba79f4a2d5cc2b97dce76d87bf6e7cdebbc257 --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,25 @@ +""" +Code adopted from pix2pixHD: +https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py +""" +import os + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c55177f7442010bc1fcc64de3d142585c22adc0 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,13 @@ + +def aggregate_loss_dict(agg_loss_dict): + mean_vals = {} + for output in agg_loss_dict: + for key in output: + mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] + for key in mean_vals: + if len(mean_vals[key]) > 0: + mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) + else: + print('{} has no value'.format(key)) + mean_vals[key] = 0 + return mean_vals diff --git a/utils/wandb_utils.py b/utils/wandb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0061eb569dee40bbe68f244b286976412fe6dece --- /dev/null +++ b/utils/wandb_utils.py @@ -0,0 +1,47 @@ +import datetime +import os +import numpy as np +import wandb + +from utils import common + + +class WBLogger: + + def __init__(self, opts): + wandb_run_name = os.path.basename(opts.exp_dir) + wandb.init(project="pixel2style2pixel", config=vars(opts), name=wandb_run_name) + + @staticmethod + def log_best_model(): + wandb.run.summary["best-model-save-time"] = datetime.datetime.now() + + @staticmethod + def log(prefix, metrics_dict, global_step): + log_dict = {f'{prefix}_{key}': value for key, value in metrics_dict.items()} + log_dict["global_step"] = global_step + wandb.log(log_dict) + + @staticmethod + def log_dataset_wandb(dataset, dataset_name, n_images=16): + idxs = np.random.choice(a=range(len(dataset)), size=n_images, replace=False) + data = [wandb.Image(dataset.source_paths[idx]) for idx in idxs] + wandb.log({f"{dataset_name} Data Samples": data}) + + @staticmethod + def log_images_to_wandb(x, y, y_hat, id_logs, prefix, step, opts): + im_data = [] + column_names = ["Source", "Target", "Output"] + if id_logs is not None: + column_names.append("ID Diff Output to Target") + for i in range(len(x)): + cur_im_data = [ + wandb.Image(common.log_input_image(x[i], opts)), + wandb.Image(common.tensor2im(y[i])), + wandb.Image(common.tensor2im(y_hat[i])), + ] + if id_logs is not None: + cur_im_data.append(id_logs[i]["diff_target"]) + im_data.append(cur_im_data) + outputs_table = wandb.Table(data=im_data, columns=column_names) + wandb.log({f"{prefix.title()} Step {step} Output Samples": outputs_table})