diff --git a/.gitattributes b/.gitattributes index 818d649bf21cdef29b21f885c8f770f9baa1714e..957b2579c6ef20995a09efd9a17f8fd90606f5ed 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,6 +1,7 @@ *.7z filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text *.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text *.bz2 filter=lfs diff=lfs merge=lfs -text *.ftz filter=lfs diff=lfs merge=lfs -text *.gz filter=lfs diff=lfs merge=lfs -text @@ -9,13 +10,9 @@ *.lfs.* filter=lfs diff=lfs merge=lfs -text *.model filter=lfs diff=lfs merge=lfs -text *.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text *.onnx filter=lfs diff=lfs merge=lfs -text *.ot filter=lfs diff=lfs merge=lfs -text *.parquet filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text *.pb filter=lfs diff=lfs merge=lfs -text *.pt filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text @@ -24,8 +21,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.tar.* filter=lfs diff=lfs merge=lfs -text *.tflite filter=lfs diff=lfs merge=lfs -text *.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text *.xz filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..464c587b66b0cdb32019704a37e90e9a4252c531 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Min Jin Chong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
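The `.gitattributes` hunk above adjusts which globs are routed through Git LFS: it adds `*.bin.*`, drops `*.npy`, `*.npz`, `*.pickle`, `*.pkl`, and `*.wasm`, and replaces `*.zst` with `*.zstandard`. As a rough, optional sanity check (not part of this diff), the sketch below approximates the matching with Python's `fnmatch`; real attribute matching follows gitignore-style rules, so treat this only as an approximation, and the sample filenames as hypothetical.

```python
from fnmatch import fnmatch

# Subset of the LFS-tracked globs from the updated .gitattributes above.
lfs_patterns = ["*.bin", "*.bin.*", "*.pt", "*.pth", "*.onnx", "*.zstandard"]

def tracked_by_lfs(filename: str) -> bool:
    """Return True if the filename matches any of the LFS glob patterns."""
    return any(fnmatch(filename, pattern) for pattern in lfs_patterns)

# Hypothetical filenames, just to illustrate the effect of the change.
for name in ["stylegan2-ffhq-config-f.pt", "model.bin.001", "weights.npy"]:
    print(name, "->", tracked_by_lfs(name))  # weights.npy is no longer matched
```

In a checked-out clone, `git check-attr filter -- <path>` gives the authoritative answer for any given file.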
diff --git a/README.md b/README.md index ebaf1fe08d91e6028e030ca148b0fd75830e2ed1..98158fd9dbff2afc2f0d207cfbd825bf48a31844 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,38 @@ --- -title: JoJoGan Powerhow2 -emoji: 📚 -colorFrom: red -colorTo: blue +title: JoJoGAN +emoji: 🌍 +colorFrom: green +colorTo: yellow sdk: gradio -sdk_version: 3.2 +sdk_version: 3.1.1 app_file: app.py pinned: false --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Configuration + +`title`: _string_ +Display title for the Space + +`emoji`: _string_ +Space emoji (emoji-only character allowed) + +`colorFrom`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`colorTo`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`sdk`: _string_ +Can be either `gradio` or `streamlit` + +`sdk_version` : _string_ +Only applicable for `streamlit` SDK. +See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions. + +`app_file`: _string_ +Path to your main application file (which contains either `gradio` or `streamlit` Python code). +Path is relative to the root of the repository. + +`pinned`: _boolean_ +Whether the Space stays on top of your list. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..df2814cae8ab12b97c33c34c03a6498eb703d0e9 --- /dev/null +++ b/app.py @@ -0,0 +1,204 @@ +import os +from PIL import Image +import torch +import gradio as gr +import torch +torch.backends.cudnn.benchmark = True +from torchvision import transforms, utils +from util import * +from PIL import Image +import math +import random +import numpy as np +from torch import nn, autograd, optim +from torch.nn import functional as F +from tqdm import tqdm +import lpips +from model import * + + +#from e4e_projection import projection as e4e_projection + +from copy import deepcopy +import imageio + +import os +import sys +import numpy as np +from PIL import Image +import torch +import torchvision.transforms as transforms +from argparse import Namespace +from e4e.models.psp import pSp +from util import * +from huggingface_hub import hf_hub_download + +device= 'cpu' +model_path_e = hf_hub_download(repo_id="akhaliq/JoJoGAN_e4e_ffhq_encode", filename="e4e_ffhq_encode.pt") +ckpt = torch.load(model_path_e, map_location='cpu') +opts = ckpt['opts'] +opts['checkpoint_path'] = model_path_e +opts= Namespace(**opts) +net = pSp(opts, device).eval().to(device) + +@ torch.no_grad() +def projection(img, name, device='cuda'): + + + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(256), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + img = transform(img).unsqueeze(0).to(device) + images, w_plus = net(img, randomize_noise=False, return_latents=True) + result_file = {} + result_file['latent'] = w_plus[0] + torch.save(result_file, name) + return w_plus[0] + + + + +device = 'cpu' + + +latent_dim = 512 + +model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt") +original_generator = Generator(1024, latent_dim, 8, 2).to(device) +ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage) +original_generator.load_state_dict(ckpt["g_ema"], strict=False) +mean_latent = original_generator.mean_latent(10000) + +generatorjojo = deepcopy(original_generator) + +generatordisney = deepcopy(original_generator) + +generatorjinx 
= deepcopy(original_generator) + +generatorcaitlyn = deepcopy(original_generator) + +generatoryasuho = deepcopy(original_generator) + +generatorarcanemulti = deepcopy(original_generator) + +generatorart = deepcopy(original_generator) + +generatorspider = deepcopy(original_generator) + +generatorsketch = deepcopy(original_generator) + + +transform = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] +) + + + + +modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt") + + +ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage) +generatorjojo.load_state_dict(ckptjojo["g"], strict=False) + + +modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt") + +ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage) +generatordisney.load_state_dict(ckptdisney["g"], strict=False) + + +modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt") + +ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage) +generatorjinx.load_state_dict(ckptjinx["g"], strict=False) + + +modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt") + +ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage) +generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False) + + +modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt") + +ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage) +generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False) + + +model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt") + +ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage) +generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False) + + +modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt") + +ckptart = torch.load(modelart, map_location=lambda storage, loc: storage) +generatorart.load_state_dict(ckptart["g"], strict=False) + + +modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt") + +ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage) +generatorspider.load_state_dict(ckptspider["g"], strict=False) + +modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt") + +ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage) +generatorsketch.load_state_dict(ckptsketch["g"], strict=False) + +def inference(img, model): + img.save('out.jpg') + aligned_face = align_face('out.jpg') + + my_w = projection(aligned_face, "test.pt", device).unsqueeze(0) + if model == 'JoJo': + with torch.no_grad(): + my_sample = generatorjojo(my_w, input_is_latent=True) + elif model == 'Disney': + with torch.no_grad(): + my_sample = generatordisney(my_w, input_is_latent=True) + elif model == 'Jinx': + with torch.no_grad(): + my_sample = generatorjinx(my_w, input_is_latent=True) + elif model == 'Caitlyn': + with torch.no_grad(): + my_sample = generatorcaitlyn(my_w, input_is_latent=True) + elif model == 'Yasuho': + with torch.no_grad(): + my_sample = generatoryasuho(my_w, input_is_latent=True) + elif model == 'Arcane Multi': + with torch.no_grad(): + my_sample = 
generatorarcanemulti(my_w, input_is_latent=True) + elif model == 'Art': + with torch.no_grad(): + my_sample = generatorart(my_w, input_is_latent=True) + elif model == 'Spider-Verse': + with torch.no_grad(): + my_sample = generatorspider(my_w, input_is_latent=True) + else: + with torch.no_grad(): + my_sample = generatorsketch(my_w, input_is_latent=True) + + + npimage = my_sample[0].permute(1, 2, 0).detach().numpy() + imageio.imwrite('filename.jpeg', npimage) + return 'filename.jpeg' + +title = "JoJoGAN" +description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below." + +article = "

JoJoGAN: One Shot Face Stylization | Github Repo Pytorch
" + +examples=[['mona.png','Jinx']] +gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch() diff --git a/e4e/.gitignore b/e4e/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b6e47617de110dea7ca47e087ff1347cc2646eda --- /dev/null +++ b/e4e/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/e4e/criteria/__init__.py b/e4e/criteria/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/criteria/id_loss.py b/e4e/criteria/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..bab806172eff18c0630536ae96817508c3197b8b --- /dev/null +++ b/e4e/criteria/id_loss.py @@ -0,0 +1,47 @@ +import torch +from torch import nn +from configs.paths_config import model_paths +from models.encoders.model_irse import Backbone + + +class IDLoss(nn.Module): + def __init__(self): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + for module in [self.facenet, self.face_pool]: + for param in module.parameters(): + param.requires_grad = False + + def extract_feats(self, x): + x = x[:, :, 35:223, 32:220] # Crop interesting region + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + id_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + id_logs.append({'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views)}) + loss += 1 - diff_target + id_diff = float(diff_target) - float(diff_views) + sim_improvement += id_diff + count += 1 + + return loss / count, sim_improvement / count, id_logs diff --git a/e4e/criteria/lpips/__init__.py b/e4e/criteria/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/criteria/lpips/lpips.py b/e4e/criteria/lpips/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..1add6acc84c1c04cfcb536cf31ec5acdf24b716b --- /dev/null +++ b/e4e/criteria/lpips/lpips.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from criteria.lpips.networks import get_network, LinLayers +from criteria.lpips.utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. 
+ """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type).to("cuda") + + # linear layers + self.lin = LinLayers(self.net.n_channels_list).to("cuda") + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0)) / x.shape[0] diff --git a/e4e/criteria/lpips/networks.py b/e4e/criteria/lpips/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0d13ad2d560278f16586da68d3a5eadb26e746 --- /dev/null +++ b/e4e/criteria/lpips/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from criteria.lpips.utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(True).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) \ No newline at end of file diff --git a/e4e/criteria/lpips/utils.py b/e4e/criteria/lpips/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5 --- /dev/null +++ b/e4e/criteria/lpips/utils.py @@ -0,0 +1,30 @@ +from collections 
import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/e4e/criteria/moco_loss.py b/e4e/criteria/moco_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8fb13fbd426202cff9014c876c85b0d5c4ec6a9d --- /dev/null +++ b/e4e/criteria/moco_loss.py @@ -0,0 +1,71 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from configs.paths_config import model_paths + + +class MocoLoss(nn.Module): + + def __init__(self, opts): + super(MocoLoss, self).__init__() + print("Loading MOCO model from path: {}".format(model_paths["moco"])) + self.model = self.__load_model() + self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + @staticmethod + def __load_model(): + import torchvision.models as models + model = models.__dict__["resnet50"]() + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ['fc.weight', 'fc.bias']: + param.requires_grad = False + checkpoint = torch.load(model_paths['moco'], map_location="cpu") + state_dict = checkpoint['state_dict'] + # rename moco pre-trained keys + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): + # remove prefix + state_dict[k[len("module.encoder_q."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + # remove output layer + model = nn.Sequential(*list(model.children())[:-1]).cuda() + return model + + def extract_feats(self, x): + x = F.interpolate(x, size=224) + x_feats = self.model(x) + x_feats = nn.functional.normalize(x_feats, dim=1) + x_feats = x_feats.squeeze() + return x_feats + + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + sim_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + sim_logs.append({'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views)}) + loss += 1 - diff_target + sim_diff = float(diff_target) - float(diff_views) + sim_improvement += sim_diff + count += 1 + + return loss / count, sim_improvement / count, sim_logs diff --git a/e4e/criteria/w_norm.py b/e4e/criteria/w_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..a45ab6f67d8a3f7051be4b7236fa2f38446fd2c1 --- /dev/null +++ b/e4e/criteria/w_norm.py @@ -0,0 +1,14 @@ 
+import torch +from torch import nn + + +class WNormLoss(nn.Module): + + def __init__(self, start_from_latent_avg=True): + super(WNormLoss, self).__init__() + self.start_from_latent_avg = start_from_latent_avg + + def forward(self, latent, latent_avg=None): + if self.start_from_latent_avg: + latent = latent - latent_avg + return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] diff --git a/e4e/datasets/__init__.py b/e4e/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/datasets/gt_res_dataset.py b/e4e/datasets/gt_res_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c0beacfee5335aa10aa7e8b7cabe206d7f9a56f7 --- /dev/null +++ b/e4e/datasets/gt_res_dataset.py @@ -0,0 +1,32 @@ +#!/usr/bin/python +# encoding: utf-8 +import os +from torch.utils.data import Dataset +from PIL import Image +import torch + +class GTResDataset(Dataset): + + def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): + self.pairs = [] + for f in os.listdir(root_path): + image_path = os.path.join(root_path, f) + gt_path = os.path.join(gt_dir, f) + if f.endswith(".jpg") or f.endswith(".png"): + self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None]) + self.transform = transform + self.transform_train = transform_train + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + from_path, to_path, _ = self.pairs[index] + from_im = Image.open(from_path).convert('RGB') + to_im = Image.open(to_path).convert('RGB') + + if self.transform: + to_im = self.transform(to_im) + from_im = self.transform(from_im) + + return from_im, to_im diff --git a/e4e/datasets/images_dataset.py b/e4e/datasets/images_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..00c54c7db944569a749af4c6f0c4d99fcc37f9cc --- /dev/null +++ b/e4e/datasets/images_dataset.py @@ -0,0 +1,33 @@ +from torch.utils.data import Dataset +from PIL import Image +from utils import data_utils + + +class ImagesDataset(Dataset): + + def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): + self.source_paths = sorted(data_utils.make_dataset(source_root)) + self.target_paths = sorted(data_utils.make_dataset(target_root)) + self.source_transform = source_transform + self.target_transform = target_transform + self.opts = opts + + def __len__(self): + return len(self.source_paths) + + def __getitem__(self, index): + from_path = self.source_paths[index] + from_im = Image.open(from_path) + from_im = from_im.convert('RGB') + + to_path = self.target_paths[index] + to_im = Image.open(to_path).convert('RGB') + if self.target_transform: + to_im = self.target_transform(to_im) + + if self.source_transform: + from_im = self.source_transform(from_im) + else: + from_im = to_im + + return from_im, to_im diff --git a/e4e/datasets/inference_dataset.py b/e4e/datasets/inference_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fb577d7b538d634f27013c2784d2ea32143154cb --- /dev/null +++ b/e4e/datasets/inference_dataset.py @@ -0,0 +1,25 @@ +from torch.utils.data import Dataset +from PIL import Image +from utils import data_utils + + +class InferenceDataset(Dataset): + + def __init__(self, root, opts, transform=None, preprocess=None): + self.paths = sorted(data_utils.make_dataset(root)) + self.transform = transform + self.preprocess = preprocess + self.opts = opts + + def __len__(self): + return len(self.paths) + 
+ def __getitem__(self, index): + from_path = self.paths[index] + if self.preprocess is not None: + from_im = self.preprocess(from_path) + else: + from_im = Image.open(from_path).convert('RGB') + if self.transform: + from_im = self.transform(from_im) + return from_im diff --git a/e4e/editings/ganspace.py b/e4e/editings/ganspace.py new file mode 100644 index 0000000000000000000000000000000000000000..0c286a421280c542e9776a75e64bb65409da8fc7 --- /dev/null +++ b/e4e/editings/ganspace.py @@ -0,0 +1,22 @@ +import torch + + +def edit(latents, pca, edit_directions): + edit_latents = [] + for latent in latents: + for pca_idx, start, end, strength in edit_directions: + delta = get_delta(pca, latent, pca_idx, strength) + delta_padded = torch.zeros(latent.shape).to('cuda') + delta_padded[start:end] += delta.repeat(end - start, 1) + edit_latents.append(latent + delta_padded) + return torch.stack(edit_latents) + + +def get_delta(pca, latent, idx, strength): + # pca: ganspace checkpoint. latent: (16, 512) w+ + w_centered = latent - pca['mean'].to('cuda') + lat_comp = pca['comp'].to('cuda') + lat_std = pca['std'].to('cuda') + w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx] + delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx] + return delta diff --git a/e4e/editings/ganspace_pca/cars_pca.pt b/e4e/editings/ganspace_pca/cars_pca.pt new file mode 100644 index 0000000000000000000000000000000000000000..41c2618317f92be5089f99e1f566e9a45650b1bb --- /dev/null +++ b/e4e/editings/ganspace_pca/cars_pca.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392 +size 167562 diff --git a/e4e/editings/ganspace_pca/ffhq_pca.pt b/e4e/editings/ganspace_pca/ffhq_pca.pt new file mode 100644 index 0000000000000000000000000000000000000000..8c8be273036803a6845ad067c8f659867343932d --- /dev/null +++ b/e4e/editings/ganspace_pca/ffhq_pca.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36 +size 167562 diff --git a/e4e/editings/interfacegan_directions/age.pt b/e4e/editings/interfacegan_directions/age.pt new file mode 100644 index 0000000000000000000000000000000000000000..64cdd22d071c643c59ce94d58334f09f647e8a83 --- /dev/null +++ b/e4e/editings/interfacegan_directions/age.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0 +size 2808 diff --git a/e4e/editings/interfacegan_directions/pose.pt b/e4e/editings/interfacegan_directions/pose.pt new file mode 100644 index 0000000000000000000000000000000000000000..2b6ceffe285303e7b2b09287167dba965283570b --- /dev/null +++ b/e4e/editings/interfacegan_directions/pose.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d +size 37624 diff --git a/e4e/editings/interfacegan_directions/smile.pt b/e4e/editings/interfacegan_directions/smile.pt new file mode 100644 index 0000000000000000000000000000000000000000..eeedc44689954510ce2c3bb585f9f9968ee06825 --- /dev/null +++ b/e4e/editings/interfacegan_directions/smile.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653 +size 2808 diff --git a/e4e/editings/latent_editor.py b/e4e/editings/latent_editor.py new file mode 100644 index 
0000000000000000000000000000000000000000..4bebca2f5c86f71b58fa1f30d24bfcb0da06d88f --- /dev/null +++ b/e4e/editings/latent_editor.py @@ -0,0 +1,45 @@ +import torch +import sys +sys.path.append(".") +sys.path.append("..") +from editings import ganspace, sefa +from utils.common import tensor2im + + +class LatentEditor(object): + def __init__(self, stylegan_generator, is_cars=False): + self.generator = stylegan_generator + self.is_cars = is_cars # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output. + + def apply_ganspace(self, latent, ganspace_pca, edit_directions): + edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions) + return self._latents_to_image(edit_latents) + + def apply_interfacegan(self, latent, direction, factor=1, factor_range=None): + edit_latents = [] + if factor_range is not None: # Apply a range of editing factors. for example, (-5, 5) + for f in range(*factor_range): + edit_latent = latent + f * direction + edit_latents.append(edit_latent) + edit_latents = torch.cat(edit_latents) + else: + edit_latents = latent + factor * direction + return self._latents_to_image(edit_latents) + + def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs): + edit_latents = sefa.edit(self.generator, latent, indices, **kwargs) + return self._latents_to_image(edit_latents) + + # Currently, in order to apply StyleFlow editings, one should run inference, + # save the latent codes and load them form the official StyleFlow repository. + # def apply_styleflow(self): + # pass + + def _latents_to_image(self, latents): + with torch.no_grad(): + images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True) + if self.is_cars: + images = images[:, :, 64:448, :] # 512x512 -> 384x512 + horizontal_concat_image = torch.cat(list(images), 2) + final_image = tensor2im(horizontal_concat_image) + return final_image diff --git a/e4e/editings/sefa.py b/e4e/editings/sefa.py new file mode 100644 index 0000000000000000000000000000000000000000..db7083ce463b765a7cf452807883a3b85fb63fa5 --- /dev/null +++ b/e4e/editings/sefa.py @@ -0,0 +1,46 @@ +import torch +import numpy as np +from tqdm import tqdm + + +def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11): + + layers, boundaries, values = factorize_weight(generator, indices) + codes = latents.detach().cpu().numpy() # (1,18,512) + + # Generate visualization pages. 
+ distances = np.linspace(start_distance, end_distance, step) + num_sam = num_samples + num_sem = semantics + + edited_latents = [] + for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False): + boundary = boundaries[sem_id:sem_id + 1] + for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False): + code = codes[sam_id:sam_id + 1] + for col_id, d in enumerate(distances, start=1): + temp_code = code.copy() + temp_code[:, layers, :] += boundary * d + edited_latents.append(torch.from_numpy(temp_code).float().cuda()) + return torch.cat(edited_latents) + + +def factorize_weight(g_ema, layers='all'): + + weights = [] + if layers == 'all' or 0 in layers: + weight = g_ema.conv1.conv.modulation.weight.T + weights.append(weight.cpu().detach().numpy()) + + if layers == 'all': + layers = list(range(g_ema.num_layers - 1)) + else: + layers = [l - 1 for l in layers if l != 0] + + for idx in layers: + weight = g_ema.convs[idx].conv.modulation.weight.T + weights.append(weight.cpu().detach().numpy()) + weight = np.concatenate(weights, axis=1).astype(np.float32) + weight = weight / np.linalg.norm(weight, axis=0, keepdims=True) + eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T)) + return layers, eigen_vectors.T, eigen_values diff --git a/e4e/environment/e4e_env.yaml b/e4e/environment/e4e_env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f537615ebb47afd74b5a9856fb9cbea2e0c4bf4 --- /dev/null +++ b/e4e/environment/e4e_env.yaml @@ -0,0 +1,73 @@ +name: e4e_env +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - ca-certificates=2020.4.5.1=hecc5488_0 + - certifi=2020.4.5.1=py36h9f0ad1d_0 + - libedit=3.1.20181209=hc058e9b_0 + - libffi=3.2.1=hd88cf55_4 + - libgcc-ng=9.1.0=hdf63c60_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.0=hc9558a2_0 + - openssl=1.1.1g=h516909a_0 + - pip=20.0.2=py36_3 + - python=3.6.7=h0371630_0 + - python_abi=3.6=1_cp36m + - readline=7.0=h7b6447c_5 + - setuptools=46.4.0=py36_0 + - sqlite=3.31.1=h62c20be_1 + - tk=8.6.8=hbc83047_0 + - wheel=0.34.2=py36_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - pip: + - absl-py==0.9.0 + - cachetools==4.1.0 + - chardet==3.0.4 + - cycler==0.10.0 + - decorator==4.4.2 + - future==0.18.2 + - google-auth==1.15.0 + - google-auth-oauthlib==0.4.1 + - grpcio==1.29.0 + - idna==2.9 + - imageio==2.8.0 + - importlib-metadata==1.6.0 + - kiwisolver==1.2.0 + - markdown==3.2.2 + - matplotlib==3.2.1 + - mxnet==1.6.0 + - networkx==2.4 + - numpy==1.18.4 + - oauthlib==3.1.0 + - opencv-python==4.2.0.34 + - pillow==7.1.2 + - protobuf==3.12.1 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pyparsing==2.4.7 + - python-dateutil==2.8.1 + - pytorch-lightning==0.7.1 + - pywavelets==1.1.1 + - requests==2.23.0 + - requests-oauthlib==1.3.0 + - rsa==4.0 + - scikit-image==0.17.2 + - scipy==1.4.1 + - six==1.15.0 + - tensorboard==2.2.1 + - tensorboard-plugin-wit==1.6.0.post3 + - tensorboardx==1.9 + - tifffile==2020.5.25 + - torch==1.6.0 + - torchvision==0.7.1 + - tqdm==4.46.0 + - urllib3==1.25.9 + - werkzeug==1.0.1 + - zipp==3.1.0 + - pyaml +prefix: ~/anaconda3/envs/e4e_env + diff --git a/e4e/metrics/LEC.py b/e4e/metrics/LEC.py new file mode 100644 index 0000000000000000000000000000000000000000..3eef2d2f00a4d757a56b6e845a8fde16aab306ab --- /dev/null +++ b/e4e/metrics/LEC.py @@ -0,0 +1,134 @@ +import sys +import argparse +import torch +import numpy as np +from torch.utils.data import DataLoader + +sys.path.append(".") +sys.path.append("..") + +from configs import 
data_configs +from datasets.images_dataset import ImagesDataset +from utils.model_utils import setup_model + + +class LEC: + def __init__(self, net, is_cars=False): + """ + Latent Editing Consistency metric as proposed in the main paper. + :param net: e4e model loaded over the pSp framework. + :param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images. + """ + self.net = net + self.is_cars = is_cars + + def _encode(self, images): + """ + Encodes the given images into StyleGAN's latent space. + :param images: Tensor of shape NxCxHxW representing the images to be encoded. + :return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space). + """ + codes = self.net.encoder(images) + assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}" + # normalize with respect to the center of an average face + if self.net.opts.start_from_latent_avg: + codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1) + return codes + + def _generate(self, codes): + """ + Generate the StyleGAN2 images of the given codes + :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space). + :return: Tensor of shape NxCxHxW representing the generated images. + """ + images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True) + images = self.net.face_pool(images) + if self.is_cars: + images = images[:, :, 32:224, :] + return images + + @staticmethod + def _filter_outliers(arr): + arr = np.array(arr) + + lo = np.percentile(arr, 1, interpolation="lower") + hi = np.percentile(arr, 99, interpolation="higher") + return np.extract( + np.logical_and(lo <= arr, arr <= hi), arr + ) + + def calculate_metric(self, data_loader, edit_function, inverse_edit_function): + """ + Calculate the LEC metric score. + :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader. + :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the + latent space. + :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the + `edit_function` parameter. + :return: The LEC metric score. 
+ """ + distances = [] + with torch.no_grad(): + for batch in data_loader: + x, _ = batch + inputs = x.to(device).float() + + codes = self._encode(inputs) + edited_codes = edit_function(codes) + edited_image = self._generate(edited_codes) + edited_image_inversion_codes = self._encode(edited_image) + inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes) + + dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean() + distances.append(dist.to("cpu").numpy()) + + distances = self._filter_outliers(distances) + return distances.mean() + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser(description="LEC metric calculator") + + parser.add_argument("--batch", type=int, default=8, help="batch size for the models") + parser.add_argument("--images_dir", type=str, default=None, + help="Path to the images directory on which we calculate the LEC score") + parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints") + + args = parser.parse_args() + print(args) + + net, opts = setup_model(args.ckpt, device) + dataset_args = data_configs.DATASETS[opts.dataset_type] + transforms_dict = dataset_args['transforms'](opts).get_transforms() + + images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir + test_dataset = ImagesDataset(source_root=images_directory, + target_root=images_directory, + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_test'], + opts=opts) + + data_loader = DataLoader(test_dataset, + batch_size=args.batch, + shuffle=False, + num_workers=2, + drop_last=True) + + print(f'dataset length: {len(test_dataset)}') + + # In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric. + # Change the provided example according to your domain and needs. 
+ direction = torch.load('../editings/interfacegan_directions/age.pt').to(device) + + def edit_func_example(codes): + return codes + 3 * direction + + + def inverse_edit_func_example(codes): + return codes - 3 * direction + + lec = LEC(net, is_cars='car' in opts.dataset_type) + result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example) + print(f"LEC: {result}") diff --git a/e4e/models/__init__.py b/e4e/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/models/discriminator.py b/e4e/models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..16bf3722c7f2e35cdc9bd177a33ed0975e67200d --- /dev/null +++ b/e4e/models/discriminator.py @@ -0,0 +1,20 @@ +from torch import nn + + +class LatentCodesDiscriminator(nn.Module): + def __init__(self, style_dim, n_mlp): + super().__init__() + + self.style_dim = style_dim + + layers = [] + for i in range(n_mlp-1): + layers.append( + nn.Linear(style_dim, style_dim) + ) + layers.append(nn.LeakyReLU(0.2)) + layers.append(nn.Linear(512, 1)) + self.mlp = nn.Sequential(*layers) + + def forward(self, w): + return self.mlp(w) diff --git a/e4e/models/encoders/__init__.py b/e4e/models/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/models/encoders/helpers.py b/e4e/models/encoders/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a58b34ea5ca6912fe53c63dede0a8696f5c024 --- /dev/null +++ b/e4e/models/encoders/helpers.py @@ -0,0 +1,140 @@ +from collections import namedtuple +import torch +import torch.nn.functional as F +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +def _upsample_add(x, y): + """Upsample and add two feature maps. + Args: + x: (Variable) top feature map to be upsampled. + y: (Variable) lateral feature map. + Returns: + (Variable) added feature map. + Note in PyTorch, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. 
+ """ + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y diff --git a/e4e/models/encoders/model_irse.py b/e4e/models/encoders/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..6a94d67542f961ff6533f0335cf4cb0fa54024fb --- /dev/null +++ b/e4e/models/encoders/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from e4e.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/e4e/models/encoders/psp_encoders.py b/e4e/models/encoders/psp_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..dc49acd11f062cbd29f839ee3c04bce7fa84f479 --- /dev/null +++ b/e4e/models/encoders/psp_encoders.py @@ -0,0 +1,200 @@ +from enum import Enum +import math +import numpy as np +import torch +from torch import nn +from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module + +from e4e.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add +from e4e.models.stylegan2.model import EqualLinear + + 
+class ProgressiveStage(Enum): + WTraining = 0 + Delta1Training = 1 + Delta2Training = 2 + Delta3Training = 3 + Delta4Training = 4 + Delta5Training = 5 + Delta6Training = 6 + Delta7Training = 7 + Delta8Training = 8 + Delta9Training = 9 + Delta10Training = 10 + Delta11Training = 11 + Delta12Training = 12 + Delta13Training = 13 + Delta14Training = 14 + Delta15Training = 15 + Delta16Training = 16 + Delta17Training = 17 + Inference = 18 + + +class GradualStyleBlock(Module): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.view(-1, self.out_c) + x = self.linear(x) + return x + + +class GradualStyleEncoder(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = _upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = _upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = torch.stack(latents, dim=1) + return out + + +class Encoder4Editing(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(Encoder4Editing, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in 
blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + self.progressive_stage = ProgressiveStage.Inference + + def get_deltas_starting_dimensions(self): + ''' Get a list of the initial dimension of every delta from which it is applied ''' + return list(range(self.style_count)) # Each dimension has a delta applied to it + + def set_progressive_stage(self, new_stage: ProgressiveStage): + self.progressive_stage = new_stage + print('Changed progressive stage to: ', new_stage) + + def forward(self, x): + x = self.input_layer(x) + + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + # Infer main W and duplicate it + w0 = self.styles[0](c3) + w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) + stage = self.progressive_stage.value + features = c3 + for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas + if i == self.coarse_ind: + p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features + features = p2 + elif i == self.middle_ind: + p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features + features = p1 + delta_i = self.styles[i](features) + w[:, i] += delta_i + return w diff --git a/e4e/models/latent_codes_pool.py b/e4e/models/latent_codes_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..0281d4b5e80f8eb26e824fa35b4f908dcb6634e6 --- /dev/null +++ b/e4e/models/latent_codes_pool.py @@ -0,0 +1,55 @@ +import random +import torch + + +class LatentCodesPool: + """This class implements latent codes buffer that stores previously generated w latent codes. + This buffer enables us to update discriminators using a history of generated w's + rather than the ones produced by the latest encoder. + """ + + def __init__(self, pool_size): + """Initialize the ImagePool class + Parameters: + pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created + """ + self.pool_size = pool_size + if self.pool_size > 0: # create an empty pool + self.num_ws = 0 + self.ws = [] + + def query(self, ws): + """Return w's from the pool. + Parameters: + ws: the latest generated w's from the generator + Returns w's from the buffer. + By 50/100, the buffer will return input w's. + By 50/100, the buffer will return w's previously stored in the buffer, + and insert the current w's to the buffer. 
+ """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return ws + return_ws = [] + for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) + # w = torch.unsqueeze(image.data, 0) + if w.ndim == 2: + i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate + w = w[i] + self.handle_w(w, return_ws) + return_ws = torch.stack(return_ws, 0) # collect all the images and return + return return_ws + + def handle_w(self, w, return_ws): + if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer + self.num_ws = self.num_ws + 1 + self.ws.append(w) + return_ws.append(w) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.ws[random_id].clone() + self.ws[random_id] = w + return_ws.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_ws.append(w) diff --git a/e4e/models/psp.py b/e4e/models/psp.py new file mode 100644 index 0000000000000000000000000000000000000000..36c0b2b7b3fdd28bc32272d0d8fcff24e4848355 --- /dev/null +++ b/e4e/models/psp.py @@ -0,0 +1,99 @@ +import matplotlib + +matplotlib.use('Agg') +import torch +from torch import nn +from e4e.models.encoders import psp_encoders +from e4e.models.stylegan2.model import Generator +from e4e.configs.paths_config import model_paths + + +def get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt + + +class pSp(nn.Module): + + def __init__(self, opts, device): + super(pSp, self).__init__() + self.opts = opts + self.device = device + # Define architecture + self.encoder = self.set_encoder() + self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2) + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + # Load weights if needed + self.load_weights() + + def set_encoder(self): + if self.opts.encoder_type == 'GradualStyleEncoder': + encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) + elif self.opts.encoder_type == 'Encoder4Editing': + encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts) + else: + raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) + return encoder + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) + self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) + self.__load_latent_avg(ckpt) + else: + print('Loading encoders weights from irse50!') + encoder_ckpt = torch.load(model_paths['ir_se50']) + self.encoder.load_state_dict(encoder_ckpt, strict=False) + print('Loading decoder weights from pretrained!') + ckpt = torch.load(self.opts.stylegan_weights) + self.decoder.load_state_dict(ckpt['g_ema'], strict=False) + self.__load_latent_avg(ckpt, repeat=self.encoder.style_count) + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.encoder(x) + # normalize with respect to the center of an average face + if 
self.opts.start_from_latent_avg: + if codes.ndim == 2: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] + else: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + images, result_latent = self.decoder([codes], + input_is_latent=input_is_latent, + randomize_noise=randomize_noise, + return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images + + def __load_latent_avg(self, ckpt, repeat=None): + if 'latent_avg' in ckpt: + self.latent_avg = ckpt['latent_avg'].to(self.device) + if repeat is not None: + self.latent_avg = self.latent_avg.repeat(repeat, 1) + else: + self.latent_avg = None diff --git a/e4e/models/stylegan2/__init__.py b/e4e/models/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/models/stylegan2/model.py b/e4e/models/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb12af85669ab6fd7f79cb14ddbdf80b2fbd83d --- /dev/null +++ b/e4e/models/stylegan2/model.py @@ -0,0 +1,678 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +if torch.cuda.is_available(): + from op.fused_act import FusedLeakyReLU, fused_leaky_relu + from op.upfirdn2d import upfirdn2d +else: + from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu + from op.upfirdn2d_cpu import upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + 
torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, 
height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = 
ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + return_features=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + elif return_features: + return image, out + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + 
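+            # p is the total padding needed so the Blur plus the stride-2 EqualConv2d below
+            # return exactly half the input resolution; pad0/pad1 split it ceil/floor across
+            # the two sides, matching ModulatedConv2d's downsample branch above.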
pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/e4e/models/stylegan2/op/__init__.py b/e4e/models/stylegan2/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/models/stylegan2/op/fused_act.py b/e4e/models/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..973a84fffde53668d31397da5fb993bbc95f7be0 --- /dev/null +++ b/e4e/models/stylegan2/op/fused_act.py @@ -0,0 +1,85 @@ +import os + +import torch +from torch import nn +from torch.autograd import Function +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +fused = load( + 'fused', + sources=[ + os.path.join(module_path, 'fused_bias_act.cpp'), + os.path.join(module_path, 'fused_bias_act_kernel.cu'), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + 
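+        # act=3, grad=1 selects the leaky-ReLU first-derivative path in the CUDA kernel
+        # (case 31 in fused_bias_act_kernel.cu): grad_output passes through where the saved
+        # forward output `out` is positive, is multiplied by negative_slope elsewhere,
+        # and is finally rescaled by `scale`.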
+ grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/e4e/models/stylegan2/op/fused_bias_act.cpp b/e4e/models/stylegan2/op/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949 --- /dev/null +++ b/e4e/models/stylegan2/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/e4e/models/stylegan2/op/fused_bias_act_kernel.cu b/e4e/models/stylegan2/op/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8 --- /dev/null +++ b/e4e/models/stylegan2/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/e4e/models/stylegan2/op/upfirdn2d.cpp b/e4e/models/stylegan2/op/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e --- /dev/null +++ b/e4e/models/stylegan2/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/e4e/models/stylegan2/op/upfirdn2d.py b/e4e/models/stylegan2/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc5a1e331c2bbb1893ac748cfd0f144ff0651b4 --- /dev/null +++ b/e4e/models/stylegan2/op/upfirdn2d.py @@ -0,0 
+1,184 @@ +import os + +import torch +from torch.autograd import Function +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'upfirdn2d.cpp'), + os.path.join(module_path, 'upfirdn2d_kernel.cu'), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, 
in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + + return out[:, ::down_y, ::down_x, :] diff --git a/e4e/models/stylegan2/op/upfirdn2d_kernel.cu b/e4e/models/stylegan2/op/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e --- /dev/null +++ b/e4e/models/stylegan2/op/upfirdn2d_kernel.cu @@ -0,0 +1,272 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + + +template +__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for 
(int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + + #pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) + #pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; + } + } + } + } +} + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; + + auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h; + int tile_out_w; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w) { + 
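+    // Launch geometry: 32 x 8 = 256 threads per block; each block stages the flipped blur kernel
+    // and an input tile in shared memory and produces tile_out_h x tile_out_w outputs per pass.
+    // grid.x covers vertical output tiles for every minor-dim slice, grid.y covers horizontal
+    // tiles, and grid.z steps through the major (batch x channel) dimension in chunks of
+    // loop_major so the grid stays within CUDA launch limits.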
p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 2: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 3: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 4: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 5: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 6: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + } + }); + + return out; +} \ No newline at end of file diff --git a/e4e/notebooks/images/car_img.jpg b/e4e/notebooks/images/car_img.jpg new file mode 100644 index 0000000000000000000000000000000000000000..162d13ddc3a7496a160925098fa9bb31d42cfd2a Binary files /dev/null and b/e4e/notebooks/images/car_img.jpg differ diff --git a/e4e/notebooks/images/church_img.jpg b/e4e/notebooks/images/church_img.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2282837b5406496f9fd3180dde8b58b288ab88cd Binary files /dev/null and b/e4e/notebooks/images/church_img.jpg differ diff --git a/e4e/notebooks/images/horse_img.jpg b/e4e/notebooks/images/horse_img.jpg new file mode 100644 index 0000000000000000000000000000000000000000..510f4b98169528fe0d03b03683907baa3dcb0ca2 Binary files /dev/null and b/e4e/notebooks/images/horse_img.jpg differ diff --git a/e4e/notebooks/images/input_img.jpg b/e4e/notebooks/images/input_img.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6534b669166946d63c5468f71a18b502eba7efb3 Binary files /dev/null and b/e4e/notebooks/images/input_img.jpg differ diff --git a/e4e/options/__init__.py b/e4e/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/options/train_options.py b/e4e/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..583ea1423fdc9a649cd7044d74d554bf0ac2bf51 --- /dev/null +++ b/e4e/options/train_options.py @@ -0,0 +1,84 @@ +from argparse import ArgumentParser +from configs.paths_config import model_paths + + +class TrainOptions: + + def __init__(self): + self.parser = ArgumentParser() + self.initialize() + + def initialize(self): + self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') + self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, + help='Type of dataset/experiment to run') + self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use') + + self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training') + self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') + self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') + self.parser.add_argument('--test_workers', default=2, type=int, + help='Number of test/inference dataloader workers') + + self.parser.add_argument('--learning_rate', default=0.0001, type=float, 
help='Optimizer learning rate') + self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') + self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') + self.parser.add_argument('--start_from_latent_avg', action='store_true', + help='Whether to add average latent vector to generate codes from encoder.') + self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone') + + self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor') + self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') + self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor') + + self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, + help='Path to StyleGAN model weights') + self.parser.add_argument('--stylegan_size', default=1024, type=int, + help='size of pretrained StyleGAN Generator') + self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') + + self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps') + self.parser.add_argument('--image_interval', default=100, type=int, + help='Interval for logging train images during training') + self.parser.add_argument('--board_interval', default=50, type=int, + help='Interval for logging metrics to tensorboard') + self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval') + self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval') + + # Discriminator flags + self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier') + self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate') + self.parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization") + self.parser.add_argument("--d_reg_every", type=int, default=16, + help="interval for applying r1 regularization") + self.parser.add_argument('--use_w_pool', action='store_true', + help='Whether to store a latnet codes pool for the discriminator\'s training') + self.parser.add_argument("--w_pool_size", type=int, default=50, + help="W\'s pool size, depends on --use_w_pool") + + # e4e specific + self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas") + self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss") + + # Progressive training + self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None, + help="The training steps of training new deltas. 
steps[i] starts the delta_i training") + self.parser.add_argument('--progressive_start', type=int, default=None, + help="The training step to start training the deltas, overrides progressive_steps") + self.parser.add_argument('--progressive_step_every', type=int, default=2_000, + help="Amount of training steps for each progressive step") + + # Save additional training info to enable future training continuation from produced checkpoints + self.parser.add_argument('--save_training_data', action='store_true', + help='Save intermediate training data to resume training from the checkpoint') + self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory') + self.parser.add_argument('--keep_optimizer', action='store_true', + help='Whether to continue from the checkpoint\'s optimizer') + self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str, + help='Path to training checkpoint, works when --save_training_data was set to True') + self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None, + help="Name of training parameters to update the loaded training checkpoint") + + def parse(self): + opts = self.parser.parse_args() + return opts diff --git a/e4e/scripts/calc_losses_on_images.py b/e4e/scripts/calc_losses_on_images.py new file mode 100644 index 0000000000000000000000000000000000000000..32b6bcee854da7ae357daf82bd986f30db9fb72c --- /dev/null +++ b/e4e/scripts/calc_losses_on_images.py @@ -0,0 +1,87 @@ +from argparse import ArgumentParser +import os +import json +import sys +from tqdm import tqdm +import numpy as np +import torch +from torch.utils.data import DataLoader +import torchvision.transforms as transforms + +sys.path.append(".") +sys.path.append("..") + +from criteria.lpips.lpips import LPIPS +from datasets.gt_res_dataset import GTResDataset + + +def parse_args(): + parser = ArgumentParser(add_help=False) + parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2']) + parser.add_argument('--data_path', type=str, default='results') + parser.add_argument('--gt_path', type=str, default='gt_images') + parser.add_argument('--workers', type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--is_cars', action='store_true') + args = parser.parse_args() + return args + + +def run(args): + resize_dims = (256, 256) + if args.is_cars: + resize_dims = (192, 256) + transform = transforms.Compose([transforms.Resize(resize_dims), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + + print('Loading dataset') + dataset = GTResDataset(root_path=args.data_path, + gt_dir=args.gt_path, + transform=transform) + + dataloader = DataLoader(dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=int(args.workers), + drop_last=True) + + if args.mode == 'lpips': + loss_func = LPIPS(net_type='alex') + elif args.mode == 'l2': + loss_func = torch.nn.MSELoss() + else: + raise Exception('Not a valid mode!') + loss_func.cuda() + + global_i = 0 + scores_dict = {} + all_scores = [] + for result_batch, gt_batch in tqdm(dataloader): + for i in range(args.batch_size): + loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda())) + all_scores.append(loss) + im_path = dataset.pairs[global_i][0] + scores_dict[os.path.basename(im_path)] = loss + global_i += 1 + + all_scores = list(scores_dict.values()) + mean = np.mean(all_scores) + std = np.std(all_scores) + result_str = 'Average loss is 
{:.2f}+-{:.2f}'.format(mean, std) + print('Finished with ', args.data_path) + print(result_str) + + out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') + if not os.path.exists(out_path): + os.makedirs(out_path) + + with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f: + f.write(result_str) + with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f: + json.dump(scores_dict, f) + + +if __name__ == '__main__': + args = parse_args() + run(args) diff --git a/e4e/scripts/inference.py b/e4e/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..185b9b34db85dcd97b9793bd5dbfc9d1ca046549 --- /dev/null +++ b/e4e/scripts/inference.py @@ -0,0 +1,133 @@ +import argparse + +import torch +import numpy as np +import sys +import os +import dlib + +sys.path.append(".") +sys.path.append("..") + +from configs import data_configs, paths_config +from datasets.inference_dataset import InferenceDataset +from torch.utils.data import DataLoader +from utils.model_utils import setup_model +from utils.common import tensor2im +from utils.alignment import align_face +from PIL import Image + + +def main(args): + net, opts = setup_model(args.ckpt, device) + is_cars = 'cars_' in opts.dataset_type + generator = net.decoder + generator.eval() + args, data_loader = setup_data_loader(args, opts) + + # Check if latents exist + latents_file_path = os.path.join(args.save_dir, 'latents.pt') + if os.path.exists(latents_file_path): + latent_codes = torch.load(latents_file_path).to(device) + else: + latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars) + torch.save(latent_codes, latents_file_path) + + if not args.latents_only: + generate_inversions(args, generator, latent_codes, is_cars=is_cars) + + +def setup_data_loader(args, opts): + dataset_args = data_configs.DATASETS[opts.dataset_type] + transforms_dict = dataset_args['transforms'](opts).get_transforms() + images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root'] + print(f"images path: {images_path}") + align_function = None + if args.align: + align_function = run_alignment + test_dataset = InferenceDataset(root=images_path, + transform=transforms_dict['transform_test'], + preprocess=align_function, + opts=opts) + + data_loader = DataLoader(test_dataset, + batch_size=args.batch, + shuffle=False, + num_workers=2, + drop_last=True) + + print(f'dataset length: {len(test_dataset)}') + + if args.n_sample is None: + args.n_sample = len(test_dataset) + return args, data_loader + + +def get_latents(net, x, is_cars=False): + codes = net.encoder(x) + if net.opts.start_from_latent_avg: + if codes.ndim == 2: + codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] + else: + codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1) + if codes.shape[1] == 18 and is_cars: + codes = codes[:, :16, :] + return codes + + +def get_all_latents(net, data_loader, n_images=None, is_cars=False): + all_latents = [] + i = 0 + with torch.no_grad(): + for batch in data_loader: + if n_images is not None and i > n_images: + break + x = batch + inputs = x.to(device).float() + latents = get_latents(net, inputs, is_cars) + all_latents.append(latents) + i += len(latents) + return torch.cat(all_latents) + + +def save_image(img, save_dir, idx): + result = tensor2im(img) + im_save_path = os.path.join(save_dir, f"{idx:05d}.jpg") + Image.fromarray(np.array(result)).save(im_save_path) + + +@torch.no_grad() +def 
generate_inversions(args, g, latent_codes, is_cars): + print('Saving inversion images') + inversions_directory_path = os.path.join(args.save_dir, 'inversions') + os.makedirs(inversions_directory_path, exist_ok=True) + for i in range(args.n_sample): + imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True) + if is_cars: + imgs = imgs[:, :, 64:448, :] + save_image(imgs[0], inversions_directory_path, i + 1) + + +def run_alignment(image_path): + predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor']) + aligned_image = align_face(filepath=image_path, predictor=predictor) + print("Aligned image has shape: {}".format(aligned_image.size)) + return aligned_image + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser(description="Inference") + parser.add_argument("--images_dir", type=str, default=None, + help="The directory of the images to be inverted") + parser.add_argument("--save_dir", type=str, default=None, + help="The directory to save the latent codes and inversion images. (default: images_dir") + parser.add_argument("--batch", type=int, default=1, help="batch size for the generator") + parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.") + parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory") + parser.add_argument("--align", action="store_true", help="align face images before inference") + parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint") + + args = parser.parse_args() + main(args) diff --git a/e4e/scripts/train.py b/e4e/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d885cfde49a0b21140e663e475918698d5e51ee3 --- /dev/null +++ b/e4e/scripts/train.py @@ -0,0 +1,88 @@ +""" +This file runs the main training/val loop +""" +import os +import json +import math +import sys +import pprint +import torch +from argparse import Namespace + +sys.path.append(".") +sys.path.append("..") + +from options.train_options import TrainOptions +from training.coach import Coach + + +def main(): + opts = TrainOptions().parse() + previous_train_ckpt = None + if opts.resume_training_from_ckpt: + opts, previous_train_ckpt = load_train_checkpoint(opts) + else: + setup_progressive_steps(opts) + create_initial_experiment_dir(opts) + + coach = Coach(opts, previous_train_ckpt) + coach.train() + + +def load_train_checkpoint(opts): + train_ckpt_path = opts.resume_training_from_ckpt + previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu') + new_opts_dict = vars(opts) + opts = previous_train_ckpt['opts'] + opts['resume_training_from_ckpt'] = train_ckpt_path + update_new_configs(opts, new_opts_dict) + pprint.pprint(opts) + opts = Namespace(**opts) + if opts.sub_exp_dir is not None: + sub_exp_dir = opts.sub_exp_dir + opts.exp_dir = os.path.join(opts.exp_dir, sub_exp_dir) + create_initial_experiment_dir(opts) + return opts, previous_train_ckpt + + +def setup_progressive_steps(opts): + log_size = int(math.log(opts.stylegan_size, 2)) + num_style_layers = 2*log_size - 2 + num_deltas = num_style_layers - 1 + if opts.progressive_start is not None: # If progressive delta training + opts.progressive_steps = [0] + next_progressive_step = opts.progressive_start + for i in range(num_deltas): + opts.progressive_steps.append(next_progressive_step) + next_progressive_step += opts.progressive_step_every + + assert 
opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \ + "Invalid progressive training input" + + +def is_valid_progressive_steps(opts, num_style_layers): + return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0 + + +def create_initial_experiment_dir(opts): + if os.path.exists(opts.exp_dir): + raise Exception('Oops... {} already exists'.format(opts.exp_dir)) + os.makedirs(opts.exp_dir) + + opts_dict = vars(opts) + pprint.pprint(opts_dict) + with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: + json.dump(opts_dict, f, indent=4, sort_keys=True) + + +def update_new_configs(ckpt_opts, new_opts): + for k, v in new_opts.items(): + if k not in ckpt_opts: + ckpt_opts[k] = v + if new_opts['update_param_list']: + for param in new_opts['update_param_list']: + ckpt_opts[param] = new_opts[param] + + +if __name__ == '__main__': + main() diff --git a/e4e/training/__init__.py b/e4e/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/training/coach.py b/e4e/training/coach.py new file mode 100644 index 0000000000000000000000000000000000000000..4c99da79e699c9362e02c289cd1425848d331d0b --- /dev/null +++ b/e4e/training/coach.py @@ -0,0 +1,437 @@ +import os +import random +import matplotlib +import matplotlib.pyplot as plt + +matplotlib.use('Agg') + +import torch +from torch import nn, autograd +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +import torch.nn.functional as F + +from utils import common, train_utils +from criteria import id_loss, moco_loss +from configs import data_configs +from datasets.images_dataset import ImagesDataset +from criteria.lpips.lpips import LPIPS +from models.psp import pSp +from models.latent_codes_pool import LatentCodesPool +from models.discriminator import LatentCodesDiscriminator +from models.encoders.psp_encoders import ProgressiveStage +from training.ranger import Ranger + +random.seed(0) +torch.manual_seed(0) + + +class Coach: + def __init__(self, opts, prev_train_checkpoint=None): + self.opts = opts + + self.global_step = 0 + + self.device = 'cuda:0' + self.opts.device = self.device + # Initialize network + self.net = pSp(self.opts).to(self.device) + + # Initialize loss + if self.opts.lpips_lambda > 0: + self.lpips_loss = LPIPS(net_type=self.opts.lpips_type).to(self.device).eval() + if self.opts.id_lambda > 0: + if 'ffhq' in self.opts.dataset_type or 'celeb' in self.opts.dataset_type: + self.id_loss = id_loss.IDLoss().to(self.device).eval() + else: + self.id_loss = moco_loss.MocoLoss(opts).to(self.device).eval() + self.mse_loss = nn.MSELoss().to(self.device).eval() + + # Initialize optimizer + self.optimizer = self.configure_optimizers() + + # Initialize discriminator + if self.opts.w_discriminator_lambda > 0: + self.discriminator = LatentCodesDiscriminator(512, 4).to(self.device) + self.discriminator_optimizer = torch.optim.Adam(list(self.discriminator.parameters()), + lr=opts.w_discriminator_lr) + self.real_w_pool = LatentCodesPool(self.opts.w_pool_size) + self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size) + + # Initialize dataset + self.train_dataset, self.test_dataset = self.configure_datasets() + self.train_dataloader = DataLoader(self.train_dataset, + batch_size=self.opts.batch_size, + shuffle=True, + num_workers=int(self.opts.workers), + drop_last=True) + self.test_dataloader = DataLoader(self.test_dataset, + 
batch_size=self.opts.test_batch_size, + shuffle=False, + num_workers=int(self.opts.test_workers), + drop_last=True) + + # Initialize logger + log_dir = os.path.join(opts.exp_dir, 'logs') + os.makedirs(log_dir, exist_ok=True) + self.logger = SummaryWriter(log_dir=log_dir) + + # Initialize checkpoint dir + self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.best_val_loss = None + if self.opts.save_interval is None: + self.opts.save_interval = self.opts.max_steps + + if prev_train_checkpoint is not None: + self.load_from_train_checkpoint(prev_train_checkpoint) + prev_train_checkpoint = None + + def load_from_train_checkpoint(self, ckpt): + print('Loading previous training data...') + self.global_step = ckpt['global_step'] + 1 + self.best_val_loss = ckpt['best_val_loss'] + self.net.load_state_dict(ckpt['state_dict']) + + if self.opts.keep_optimizer: + self.optimizer.load_state_dict(ckpt['optimizer']) + if self.opts.w_discriminator_lambda > 0: + self.discriminator.load_state_dict(ckpt['discriminator_state_dict']) + self.discriminator_optimizer.load_state_dict(ckpt['discriminator_optimizer_state_dict']) + if self.opts.progressive_steps: + self.check_for_progressive_training_update(is_resume_from_ckpt=True) + print(f'Resuming training from step {self.global_step}') + + def train(self): + self.net.train() + if self.opts.progressive_steps: + self.check_for_progressive_training_update() + while self.global_step < self.opts.max_steps: + for batch_idx, batch in enumerate(self.train_dataloader): + loss_dict = {} + if self.is_training_discriminator(): + loss_dict = self.train_discriminator(batch) + x, y, y_hat, latent = self.forward(batch) + loss, encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) + loss_dict = {**loss_dict, **encoder_loss_dict} + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # Logging related + if self.global_step % self.opts.image_interval == 0 or ( + self.global_step < 1000 and self.global_step % 25 == 0): + self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces') + if self.global_step % self.opts.board_interval == 0: + self.print_metrics(loss_dict, prefix='train') + self.log_metrics(loss_dict, prefix='train') + + # Validation related + val_loss_dict = None + if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps: + val_loss_dict = self.validate() + if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss): + self.best_val_loss = val_loss_dict['loss'] + self.checkpoint_me(val_loss_dict, is_best=True) + + if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: + if val_loss_dict is not None: + self.checkpoint_me(val_loss_dict, is_best=False) + else: + self.checkpoint_me(loss_dict, is_best=False) + + if self.global_step == self.opts.max_steps: + print('OMG, finished training!') + break + + self.global_step += 1 + if self.opts.progressive_steps: + self.check_for_progressive_training_update() + + def check_for_progressive_training_update(self, is_resume_from_ckpt=False): + for i in range(len(self.opts.progressive_steps)): + if is_resume_from_ckpt and self.global_step >= self.opts.progressive_steps[i]: # Case checkpoint + self.net.encoder.set_progressive_stage(ProgressiveStage(i)) + if self.global_step == self.opts.progressive_steps[i]: # Case training reached progressive step + self.net.encoder.set_progressive_stage(ProgressiveStage(i)) + + def 
validate(self): + self.net.eval() + agg_loss_dict = [] + for batch_idx, batch in enumerate(self.test_dataloader): + cur_loss_dict = {} + if self.is_training_discriminator(): + cur_loss_dict = self.validate_discriminator(batch) + with torch.no_grad(): + x, y, y_hat, latent = self.forward(batch) + loss, cur_encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) + cur_loss_dict = {**cur_loss_dict, **cur_encoder_loss_dict} + agg_loss_dict.append(cur_loss_dict) + + # Logging related + self.parse_and_log_images(id_logs, x, y, y_hat, + title='images/test/faces', + subscript='{:04d}'.format(batch_idx)) + + # For first step just do sanity test on small amount of data + if self.global_step == 0 and batch_idx >= 4: + self.net.train() + return None # Do not log, inaccurate in first batch + + loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict) + self.log_metrics(loss_dict, prefix='test') + self.print_metrics(loss_dict, prefix='test') + + self.net.train() + return loss_dict + + def checkpoint_me(self, loss_dict, is_best): + save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step) + save_dict = self.__get_save_dict() + checkpoint_path = os.path.join(self.checkpoint_dir, save_name) + torch.save(save_dict, checkpoint_path) + with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: + if is_best: + f.write( + '**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict)) + else: + f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict)) + + def configure_optimizers(self): + params = list(self.net.encoder.parameters()) + if self.opts.train_decoder: + params += list(self.net.decoder.parameters()) + else: + self.requires_grad(self.net.decoder, False) + if self.opts.optim_name == 'adam': + optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) + else: + optimizer = Ranger(params, lr=self.opts.learning_rate) + return optimizer + + def configure_datasets(self): + if self.opts.dataset_type not in data_configs.DATASETS.keys(): + raise Exception('{} is not a valid dataset_type'.format(self.opts.dataset_type)) + print('Loading dataset for {}'.format(self.opts.dataset_type)) + dataset_args = data_configs.DATASETS[self.opts.dataset_type] + transforms_dict = dataset_args['transforms'](self.opts).get_transforms() + train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'], + target_root=dataset_args['train_target_root'], + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_gt_train'], + opts=self.opts) + test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'], + target_root=dataset_args['test_target_root'], + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_test'], + opts=self.opts) + print("Number of training samples: {}".format(len(train_dataset))) + print("Number of test samples: {}".format(len(test_dataset))) + return train_dataset, test_dataset + + def calc_loss(self, x, y, y_hat, latent): + loss_dict = {} + loss = 0.0 + id_logs = None + if self.is_training_discriminator(): # Adversarial loss + loss_disc = 0.
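+ # non-saturating adversarial term on the encoder: each discriminated w slice is scored by the latent discriminator and penalized with softplus(-D(w)), averaged over the chosen dimensions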
+ dims_to_discriminate = self.get_dims_to_discriminate() if self.is_progressive_training() else \ + list(range(self.net.decoder.n_latent)) + + for i in dims_to_discriminate: + w = latent[:, i, :] + fake_pred = self.discriminator(w) + loss_disc += F.softplus(-fake_pred).mean() + loss_disc /= len(dims_to_discriminate) + loss_dict['encoder_discriminator_loss'] = float(loss_disc) + loss += self.opts.w_discriminator_lambda * loss_disc + + if self.opts.progressive_steps and self.net.encoder.progressive_stage.value != 18: # delta regularization loss + total_delta_loss = 0 + deltas_latent_dims = self.net.encoder.get_deltas_starting_dimensions() + + first_w = latent[:, 0, :] + for i in range(1, self.net.encoder.progressive_stage.value + 1): + curr_dim = deltas_latent_dims[i] + delta = latent[:, curr_dim, :] - first_w + delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean() + loss_dict[f"delta{i}_loss"] = float(delta_loss) + total_delta_loss += delta_loss + loss_dict['total_delta_loss'] = float(total_delta_loss) + loss += self.opts.delta_norm_lambda * total_delta_loss + + if self.opts.id_lambda > 0: # Similarity loss + loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x) + loss_dict['loss_id'] = float(loss_id) + loss_dict['id_improve'] = float(sim_improvement) + loss += loss_id * self.opts.id_lambda + if self.opts.l2_lambda > 0: + loss_l2 = F.mse_loss(y_hat, y) + loss_dict['loss_l2'] = float(loss_l2) + loss += loss_l2 * self.opts.l2_lambda + if self.opts.lpips_lambda > 0: + loss_lpips = self.lpips_loss(y_hat, y) + loss_dict['loss_lpips'] = float(loss_lpips) + loss += loss_lpips * self.opts.lpips_lambda + loss_dict['loss'] = float(loss) + return loss, loss_dict, id_logs + + def forward(self, batch): + x, y = batch + x, y = x.to(self.device).float(), y.to(self.device).float() + y_hat, latent = self.net.forward(x, return_latents=True) + if self.opts.dataset_type == "cars_encode": + y_hat = y_hat[:, :, 32:224, :] + return x, y, y_hat, latent + + def log_metrics(self, metrics_dict, prefix): + for key, value in metrics_dict.items(): + self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step) + + def print_metrics(self, metrics_dict, prefix): + print('Metrics for {}, step {}'.format(prefix, self.global_step)) + for key, value in metrics_dict.items(): + print('\t{} = '.format(key), value) + + def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2): + im_data = [] + for i in range(display_count): + cur_im_data = { + 'input_face': common.log_input_image(x[i], self.opts), + 'target_face': common.tensor2im(y[i]), + 'output_face': common.tensor2im(y_hat[i]), + } + if id_logs is not None: + for key in id_logs[i]: + cur_im_data[key] = id_logs[i][key] + im_data.append(cur_im_data) + self.log_images(title, im_data=im_data, subscript=subscript) + + def log_images(self, name, im_data, subscript=None, log_latest=False): + fig = common.vis_faces(im_data) + step = self.global_step + if log_latest: + step = 0 + if subscript: + path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step)) + else: + path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step)) + os.makedirs(os.path.dirname(path), exist_ok=True) + fig.savefig(path) + plt.close(fig) + + def __get_save_dict(self): + save_dict = { + 'state_dict': self.net.state_dict(), + 'opts': vars(self.opts) + } + # save the latent avg in state_dict for inference if truncation of w was used during training + if self.opts.start_from_latent_avg: + 
save_dict['latent_avg'] = self.net.latent_avg + + if self.opts.save_training_data: # Save necessary information to enable training continuation from checkpoint + save_dict['global_step'] = self.global_step + save_dict['optimizer'] = self.optimizer.state_dict() + save_dict['best_val_loss'] = self.best_val_loss + if self.opts.w_discriminator_lambda > 0: + save_dict['discriminator_state_dict'] = self.discriminator.state_dict() + save_dict['discriminator_optimizer_state_dict'] = self.discriminator_optimizer.state_dict() + return save_dict + + def get_dims_to_discriminate(self): + deltas_starting_dimensions = self.net.encoder.get_deltas_starting_dimensions() + return deltas_starting_dimensions[:self.net.encoder.progressive_stage.value + 1] + + def is_progressive_training(self): + return self.opts.progressive_steps is not None + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Discriminator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + def is_training_discriminator(self): + return self.opts.w_discriminator_lambda > 0 + + @staticmethod + def discriminator_loss(real_pred, fake_pred, loss_dict): + real_loss = F.softplus(-real_pred).mean() + fake_loss = F.softplus(fake_pred).mean() + + loss_dict['d_real_loss'] = float(real_loss) + loss_dict['d_fake_loss'] = float(fake_loss) + + return real_loss + fake_loss + + @staticmethod + def discriminator_r1_loss(real_pred, real_w): + grad_real, = autograd.grad( + outputs=real_pred.sum(), inputs=real_w, create_graph=True + ) + grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() + + return grad_penalty + + @staticmethod + def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + def train_discriminator(self, batch): + loss_dict = {} + x, _ = batch + x = x.to(self.device).float() + self.requires_grad(self.discriminator, True) + + with torch.no_grad(): + real_w, fake_w = self.sample_real_and_fake_latents(x) + real_pred = self.discriminator(real_w) + fake_pred = self.discriminator(fake_w) + loss = self.discriminator_loss(real_pred, fake_pred, loss_dict) + loss_dict['discriminator_loss'] = float(loss) + + self.discriminator_optimizer.zero_grad() + loss.backward() + self.discriminator_optimizer.step() + + # r1 regularization + d_regularize = self.global_step % self.opts.d_reg_every == 0 + if d_regularize: + real_w = real_w.detach() + real_w.requires_grad = True + real_pred = self.discriminator(real_w) + r1_loss = self.discriminator_r1_loss(real_pred, real_w) + + self.discriminator.zero_grad() + r1_final_loss = self.opts.r1 / 2 * r1_loss * self.opts.d_reg_every + 0 * real_pred[0] + r1_final_loss.backward() + self.discriminator_optimizer.step() + loss_dict['discriminator_r1_loss'] = float(r1_final_loss) + + # Reset to previous state + self.requires_grad(self.discriminator, False) + + return loss_dict + + def validate_discriminator(self, test_batch): + with torch.no_grad(): + loss_dict = {} + x, _ = test_batch + x = x.to(self.device).float() + real_w, fake_w = self.sample_real_and_fake_latents(x) + real_pred = self.discriminator(real_w) + fake_pred = self.discriminator(fake_w) + loss = self.discriminator_loss(real_pred, fake_pred, loss_dict) + loss_dict['discriminator_loss'] = float(loss) + return loss_dict + + def sample_real_and_fake_latents(self, x): + sample_z = torch.randn(self.opts.batch_size, 512, device=self.device) + real_w = self.net.decoder.get_latent(sample_z) + fake_w = self.net.encoder(x) + if self.is_progressive_training(): # When progressive training, feed only unique w's + 
dims_to_discriminate = self.get_dims_to_discriminate() + fake_w = fake_w[:, dims_to_discriminate, :] + if self.opts.use_w_pool: + real_w = self.real_w_pool.query(real_w) + fake_w = self.fake_w_pool.query(fake_w) + if fake_w.ndim == 3: + fake_w = fake_w[:, 0, :] + return real_w, fake_w diff --git a/e4e/training/ranger.py b/e4e/training/ranger.py new file mode 100644 index 0000000000000000000000000000000000000000..3d63264dda6df0ee40cac143440f0b5f8977a9ad --- /dev/null +++ b/e4e/training/ranger.py @@ -0,0 +1,164 @@ +# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. + +# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer +# and/or +# https://github.com/lessw2020/Best-Deep-Learning-Optimizers + +# Ranger has now been used to capture 12 records on the FastAI leaderboard. + +# This version = 20.4.11 + +# Credits: +# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization +# RAdam --> https://github.com/LiyuanLucasLiu/RAdam +# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. +# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 + +# summary of changes: +# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. +# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), +# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. +# changes 8/31/19 - fix references to *self*.N_sma_threshold; +# changed eps to 1e-5 as better default than 1e-8. + +import math +import torch +from torch.optim.optimizer import Optimizer + + +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, # lr + alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options + betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options + use_gc=True, gc_conv_only=False + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + ): + + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
+ + # prep defaults and init torch.optim base + defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, + eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # look ahead params + + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # gc on or off + self.use_gc = use_gc + + # level of gradient centralization + self.gc_gradient_threshold = 3 if gc_conv_only else 1 + + def __setstate__(self, state): + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + + if grad.is_sparse: + raise RuntimeError('Ranger optimizer does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if len(state) == 0: # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state['slow_buffer'] = torch.empty_like(p.data) + state['slow_buffer'].copy_(p.data) + + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # GC operation for Conv layers and FC layers + if grad.dim() > self.gc_gradient_threshold: + grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) + + state['step'] += 1 + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + # compute mean moving avg + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + buffered = self.radam_buffer[int(state['step'] % 10)] + + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # apply lr + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... 
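+ # every k steps, pull the slow (lookahead) buffer a fraction alpha toward the fast RAdam weights and reset the fast weights to that interpolation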
+ # we do it at the param level instead of group level + if state['step'] % group['k'] == 0: + slow_p = state['slow_buffer'] # get access to slow param tensor + slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha + p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor + + return loss \ No newline at end of file diff --git a/e4e/utils/__init__.py b/e4e/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/e4e/utils/alignment.py b/e4e/utils/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..a02798f0f7c9fdcc319f7884a491b9e6580cc8aa --- /dev/null +++ b/e4e/utils/alignment.py @@ -0,0 +1,115 @@ +import numpy as np +import PIL +import PIL.Image +import scipy +import scipy.ndimage +import dlib + + +def get_landmark(filepath, predictor): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + detector = dlib.get_frontal_face_detector() + + img = dlib.load_rgb_image(filepath) + dets = detector(img, 1) + + for k, d in enumerate(dets): + shape = predictor(img, d) + + t = list(shape.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + return lm + + +def align_face(filepath, predictor): + """ + :param filepath: str + :return: PIL Image + """ + + lm = get_landmark(filepath, predictor) + + lm_chin = lm[0: 17] # left-right + lm_eyebrow_left = lm[17: 22] # left-right + lm_eyebrow_right = lm[22: 27] # left-right + lm_nose = lm[27: 31] # top-down + lm_nostrils = lm[31: 36] # top-down + lm_eye_left = lm[36: 42] # left-clockwise + lm_eye_right = lm[42: 48] # left-clockwise + lm_mouth_outer = lm[48: 60] # left-clockwise + lm_mouth_inner = lm[60: 68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # read image + img = PIL.Image.open(filepath) + + output_size = 256 + transform_size = 256 + enable_padding = True + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, PIL.Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. 
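+ # reflect-pad beyond the crop, then feather the padded border with a gaussian blur and a blend toward the median color so the seam is hidden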
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) + if output_size < transform_size: + img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) + + # Return aligned image. + return img diff --git a/e4e/utils/common.py b/e4e/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b19e18ddcb78b06678fa18e4a76da44fc511b789 --- /dev/null +++ b/e4e/utils/common.py @@ -0,0 +1,55 @@ +from PIL import Image +import matplotlib.pyplot as plt + + +# Log images +def log_input_image(x, opts): + return tensor2im(x) + + +def tensor2im(var): + # var shape: (3, H, W) + var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() + var = ((var + 1) / 2) + var[var < 0] = 0 + var[var > 1] = 1 + var = var * 255 + return Image.fromarray(var.astype('uint8')) + + +def vis_faces(log_hooks): + display_count = len(log_hooks) + fig = plt.figure(figsize=(8, 4 * display_count)) + gs = fig.add_gridspec(display_count, 3) + for i in range(display_count): + hooks_dict = log_hooks[i] + fig.add_subplot(gs[i, 0]) + if 'diff_input' in hooks_dict: + vis_faces_with_id(hooks_dict, fig, gs, i) + else: + vis_faces_no_id(hooks_dict, fig, gs, i) + plt.tight_layout() + return fig + + +def vis_faces_with_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face']) + plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), + float(hooks_dict['diff_target']))) + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) + + +def vis_faces_no_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face'], cmap="gray") + plt.title('Input') + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target') + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output') diff --git a/e4e/utils/data_utils.py b/e4e/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ba79f4a2d5cc2b97dce76d87bf6e7cdebbc257 --- /dev/null +++ b/e4e/utils/data_utils.py @@ -0,0 +1,25 @@ +""" +Code adopted from pix2pixHD: +https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py +""" +import os + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' +] + + +def 
is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images diff --git a/e4e/utils/model_utils.py b/e4e/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e51e95578f72b3218d6d832e3b604193cb68c1d7 --- /dev/null +++ b/e4e/utils/model_utils.py @@ -0,0 +1,35 @@ +import torch +import argparse +from models.psp import pSp +from models.encoders.psp_encoders import Encoder4Editing + + +def setup_model(checkpoint_path, device='cuda'): + ckpt = torch.load(checkpoint_path, map_location='cpu') + opts = ckpt['opts'] + + opts['checkpoint_path'] = checkpoint_path + opts['device'] = device + opts = argparse.Namespace(**opts) + + net = pSp(opts) + net.eval() + net = net.to(device) + return net, opts + + +def load_e4e_standalone(checkpoint_path, device='cuda'): + ckpt = torch.load(checkpoint_path, map_location='cpu') + opts = argparse.Namespace(**ckpt['opts']) + e4e = Encoder4Editing(50, 'ir_se', opts) + e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')} + e4e.load_state_dict(e4e_dict) + e4e.eval() + e4e = e4e.to(device) + latent_avg = ckpt['latent_avg'].to(device) + + def add_latent_avg(model, inputs, outputs): + return outputs + latent_avg.repeat(outputs.shape[0], 1, 1) + + e4e.register_forward_hook(add_latent_avg) + return e4e diff --git a/e4e/utils/train_utils.py b/e4e/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c55177f7442010bc1fcc64de3d142585c22adc0 --- /dev/null +++ b/e4e/utils/train_utils.py @@ -0,0 +1,13 @@ + +def aggregate_loss_dict(agg_loss_dict): + mean_vals = {} + for output in agg_loss_dict: + for key in output: + mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] + for key in mean_vals: + if len(mean_vals[key]) > 0: + mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) + else: + print('{} has no value'.format(key)) + mean_vals[key] = 0 + return mean_vals diff --git a/e4e_projection.py b/e4e_projection.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5bc1b7301e068626460b3ac23fe44f49238d79 --- /dev/null +++ b/e4e_projection.py @@ -0,0 +1,38 @@ +import os +import sys +import numpy as np +from PIL import Image +import torch +import torchvision.transforms as transforms +from argparse import Namespace +from e4e.models.psp import pSp +from util import * + + + +@ torch.no_grad() +def projection(img, name, device='cuda'): + + + model_path = 'e4e_ffhq_encode.pt' + ckpt = torch.load(model_path, map_location='cpu') + opts = ckpt['opts'] + opts['checkpoint_path'] = model_path + opts= Namespace(**opts) + net = pSp(opts, device).eval().to(device) + + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(256), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + img = transform(img).unsqueeze(0).to(device) + images, w_plus = net(img, randomize_noise=False, return_latents=True) + result_file = {} + result_file['latent'] = w_plus[0] + torch.save(result_file, name) + return w_plus[0] diff --git a/elon.png b/elon.png new file mode 100644 index 0000000000000000000000000000000000000000..272bbbceff04d64c6eabd3f99c25350095d3c33e Binary files 
/dev/null and b/elon.png differ diff --git a/iu.jpeg b/iu.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..d89550f47ff9e351739770bb1561b320a3232cef Binary files /dev/null and b/iu.jpeg differ diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..497bf78d57c54d58cd3b55f26c718be2470a04f1 --- /dev/null +++ b/model.py @@ -0,0 +1,688 @@ +import math +import random +import functools +import operator + +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function + +from op import conv2d_gradfix +if torch.cuda.is_available(): + from op.fused_act import FusedLeakyReLU, fused_leaky_relu + from op.upfirdn2d import upfirdn2d +else: + from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu + from op.upfirdn2d_cpu import upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer("kernel", kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = conv2d_gradfix.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + 
+ else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" + ) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + fused=True, + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + self.fused = fused + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " + f"upsample={self.upsample}, downsample={self.downsample})" + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + if not self.fused: + weight = self.scale * self.weight.squeeze(0) + style = self.modulation(style) + + if self.demodulate: + w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) + dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() + + input = input * style.reshape(batch, in_channel, 1, 1) + + if self.upsample: + weight = weight.transpose(0, 1) + out = conv2d_gradfix.conv_transpose2d( + input, weight, padding=0, stride=2 + ) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) + + else: + out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) + + if self.demodulate: + out = out * dcoefs.view(batch, -1, 1, 1) + + return out + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = conv2d_gradfix.conv_transpose2d( + input, weight, padding=0, stride=2, groups=batch + ) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + 
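+ # low-pass blur the 2x-upsampled activations produced by the grouped transposed convolution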
out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = conv2d_gradfix.conv2d( + input, weight, padding=0, stride=2, groups=batch + ) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = conv2d_gradfix.conv2d( + input, weight, padding=self.padding, groups=batch + ) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + 
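+ # fixed per-resolution noise maps are registered as buffers on a bare nn.Module so they end up in the state dict without becoming trainable parameters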
self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + @torch.no_grad() + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + @torch.no_grad() + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) + ] + + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + latent = styles[0].unsqueeze(1).repeat(1, self.n_latent, 1) + else: + latent = styles + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + return image + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + layers.append(FusedLeakyReLU(out_channel, bias=bias)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = 
self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out + diff --git a/mona.png b/mona.png new file mode 100644 index 0000000000000000000000000000000000000000..95c5c5a09c73b343cd2d1911816267db3ed79619 Binary files /dev/null and b/mona.png differ diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..ace9fff330d5ce42b11d0e450d7cde99ccecfa77 --- /dev/null +++ b/packages.txt @@ -0,0 +1,4 @@ +ffmpeg +libsm6 +libxext6 +cmake \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0379116faae6fe42b29fd9ea44800c6893fac9e1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +tqdm +gdown +scikit-learn==0.22 +scipy +lpips +opencv-python-headless +torch +torchvision +imageio +dlib \ No newline at end of file diff --git a/util.py b/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7f1124405c955bdb46b0bd2cec68d47653d100 --- /dev/null +++ b/util.py @@ -0,0 +1,220 @@ +from matplotlib import pyplot as plt +import torch +import torch.nn.functional as F +import os +import cv2 +import dlib +from PIL import Image +import numpy as np +import math +import torchvision +import scipy +import scipy.ndimage +import torchvision.transforms as transforms + +from huggingface_hub import hf_hub_download + + +shape_predictor_path = hf_hub_download(repo_id="akhaliq/jojogan_dlib", filename="shape_predictor_68_face_landmarks.dat") + + +google_drive_paths = { + "models/stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK", + "models/dlibshape_predictor_68_face_landmarks.dat": "https://drive.google.com/uc?id=11BDmNKS1zxSZxkgsEvQoKgFd8J264jKp", + "models/e4e_ffhq_encode.pt": "https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7", + "models/restyle_psp_ffhq_encode.pt": "https://drive.google.com/uc?id=1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd", + "models/arcane_caitlyn.pt": 
"https://drive.google.com/uc?id=1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc", + "models/arcane_caitlyn_preserve_color.pt": "https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH", + "models/arcane_jinx_preserve_color.pt": "https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney", + "models/arcane_jinx.pt": "https://drive.google.com/uc?id=1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_", + "models/disney.pt": "https://drive.google.com/uc?id=1zbE2upakFUAx8ximYnLofFwfT8MilqJA", + "models/disney_preserve_color.pt": "https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi", + "models/jojo.pt": "https://drive.google.com/uc?id=13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4", + "models/jojo_preserve_color.pt": "https://drive.google.com/uc?id=1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2", + "models/jojo_yasuho.pt": "https://drive.google.com/uc?id=1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_", + "models/jojo_yasuho_preserve_color.pt": "https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L", + "models/supergirl.pt": "https://drive.google.com/uc?id=1L0y9IYgzLNzB-33xTpXpecsKU-t9DpVC", + "models/supergirl_preserve_color.pt": "https://drive.google.com/uc?id=1VmKGuvThWHym7YuayXxjv0fSn32lfDpE", +} + +@torch.no_grad() +def load_model(generator, model_file_path): + ensure_checkpoint_exists(model_file_path) + ckpt = torch.load(model_file_path, map_location=lambda storage, loc: storage) + generator.load_state_dict(ckpt["g_ema"], strict=False) + return generator.mean_latent(50000) + +def ensure_checkpoint_exists(model_weights_filename): + if not os.path.isfile(model_weights_filename) and ( + model_weights_filename in google_drive_paths + ): + gdrive_url = google_drive_paths[model_weights_filename] + try: + from gdown import download as drive_download + + drive_download(gdrive_url, model_weights_filename, quiet=False) + except ModuleNotFoundError: + print( + "gdown module not found.", + "pip3 install gdown or, manually download the checkpoint file:", + gdrive_url + ) + + if not os.path.isfile(model_weights_filename) and ( + model_weights_filename not in google_drive_paths + ): + print( + model_weights_filename, + " not found, you may need to manually download the model weights." 
+ ) + +# given a list of filenames, load the inverted style code +@torch.no_grad() +def load_source(files, generator, device='cuda'): + sources = [] + + for file in files: + source = torch.load(f'./inversion_codes/{file}.pt')['latent'].to(device) + + if source.size(0) != 1: + source = source.unsqueeze(0) + + if source.ndim == 3: + source = generator.get_latent(source, truncation=1, is_latent=True) + source = list2style(source) + + sources.append(source) + + sources = torch.cat(sources, 0) + if type(sources) is not list: + sources = style2list(sources) + + return sources + +def display_image(image, size=None, mode='nearest', unnorm=False, title=''): + # image is [3,h,w] or [1,3,h,w] tensor [0,1] + if not isinstance(image, torch.Tensor): + image = transforms.ToTensor()(image).unsqueeze(0) + if image.is_cuda: + image = image.cpu() + if size is not None and image.size(-1) != size: + image = F.interpolate(image, size=(size,size), mode=mode) + if image.dim() == 4: + image = image[0] + image = image.permute(1, 2, 0).detach().numpy() + plt.figure() + plt.title(title) + plt.axis('off') + plt.imshow(image) + +def get_landmark(filepath, predictor): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + detector = dlib.get_frontal_face_detector() + + img = dlib.load_rgb_image(filepath) + dets = detector(img, 1) + assert len(dets) > 0, "Face not detected, try another face image" + + for k, d in enumerate(dets): + shape = predictor(img, d) + + t = list(shape.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + return lm + + +def align_face(filepath, output_size=256, transform_size=1024, enable_padding=True): + + """ + :param filepath: str + :return: PIL Image + """ + predictor = dlib.shape_predictor(shape_predictor_path) + lm = get_landmark(filepath, predictor) + + lm_chin = lm[0: 17] # left-right + lm_eyebrow_left = lm[17: 22] # left-right + lm_eyebrow_right = lm[22: 27] # left-right + lm_nose = lm[27: 31] # top-down + lm_nostrils = lm[31: 36] # top-down + lm_eye_left = lm[36: 42] # left-clockwise + lm_eye_right = lm[42: 48] # left-clockwise + lm_mouth_outer = lm[48: 60] # left-clockwise + lm_mouth_inner = lm[60: 68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # read image + img = Image.open(filepath) + + transform_size = output_size + enable_padding = True + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. 
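+ # crop to the oriented face quad plus a small border, then shift the quad into the crop's coordinate frame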
+ border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) + if output_size < transform_size: + img = img.resize((output_size, output_size), Image.ANTIALIAS) + + # Return aligned image. + return img + +def strip_path_extension(path): + return os.path.splitext(path)[0]
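As a usage note: a minimal sketch of how the utilities above compose — align a face with util.align_face, invert it to W+ with the e4e projection, and render it through a fine-tuned StyleGAN2 generator loaded via util.load_model. The input/output paths and the choice of style checkpoint are illustrative assumptions; e4e_projection.projection additionally expects e4e_ffhq_encode.pt to be present locally.

import torch
from torchvision import utils

from model import Generator
from e4e_projection import projection
from util import align_face, load_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# crop and align the input photo with the dlib landmark pipeline from util.py
aligned = align_face('input.jpg')  # illustrative path

# invert the aligned face into W+ (assumes e4e_ffhq_encode.pt is available to e4e_projection.py)
my_w = projection(aligned, 'input_inversion.pt', device).unsqueeze(0)

# 1024px StyleGAN2 generator; 'models/jojo.pt' is one of the checkpoints listed in util.google_drive_paths
generator = Generator(1024, 512, 8, 2).to(device)
load_model(generator, 'models/jojo.pt')

with torch.no_grad():
    stylized = generator(my_w, input_is_latent=True)
utils.save_image(stylized.clamp(-1, 1), 'output.jpg', normalize=True)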