aliabd committed on
Commit
bca104a
•
1 Parent(s): c775144

copied all files from repo

Files changed (12)
  1. LICENSE +21 -0
  2. dataset.py +167 -0
  3. distributed.py +126 -0
  4. gradiodemo.py +84 -0
  5. inference.ipynb +0 -0
  6. inference_colab.ipynb +0 -0
  7. model.py +757 -0
  8. requirements.txt +10 -0
  9. teaser.gif +0 -0
  10. teaser.png +0 -0
  11. train.py +458 -0
  12. util.py +161 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Min Jin Chong
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
dataset.py ADDED
@@ -0,0 +1,167 @@
+ import torch.utils.data as data
+
+ from PIL import Image
+
+ import os
+ import os.path
+ from io import BytesIO
+
+ import lmdb
+ from torch.utils.data import Dataset
+
+ class MultiResolutionDataset(Dataset):
+     def __init__(self, path, transform, resolution=256):
+         self.env = lmdb.open(
+             path,
+             max_readers=32,
+             readonly=True,
+             lock=False,
+             readahead=False,
+             meminit=False,
+         )
+
+         if not self.env:
+             raise IOError('Cannot open lmdb dataset', path)
+
+         with self.env.begin(write=False) as txn:
+             self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
+
+         self.resolution = resolution
+         self.transform = transform
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, index):
+         with self.env.begin(write=False) as txn:
+             key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
+             img_bytes = txn.get(key)
+
+         buffer = BytesIO(img_bytes)
+         img = Image.open(buffer)
+         img = self.transform(img)
+
+         return img
+
+
+ def has_file_allowed_extension(filename, extensions):
+     """Checks if a file is an allowed extension.
+
+     Args:
+         filename (string): path to a file
+
+     Returns:
+         bool: True if the filename ends with a known image extension
+     """
+     filename_lower = filename.lower()
+     return any(filename_lower.endswith(ext) for ext in extensions)
+
+
+ def find_classes(dir):
+     classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+     classes.sort()
+     class_to_idx = {classes[i]: i for i in range(len(classes))}
+     return classes, class_to_idx
+
+
+ def make_dataset(dir, extensions):
+     images = []
+     for root, _, fnames in sorted(os.walk(dir)):
+         for fname in sorted(fnames):
+             if has_file_allowed_extension(fname, extensions):
+                 path = os.path.join(root, fname)
+                 item = (path, 0)
+                 images.append(item)
+
+     return images
+
+
+ class DatasetFolder(data.Dataset):
+     def __init__(self, root, loader, extensions, transform=None, target_transform=None):
+         # classes, class_to_idx = find_classes(root)
+         samples = make_dataset(root, extensions)
+         if len(samples) == 0:
+             raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
+                                "Supported extensions are: " + ",".join(extensions)))
+
+         self.root = root
+         self.loader = loader
+         self.extensions = extensions
+         self.samples = samples
+
+         self.transform = transform
+         self.target_transform = target_transform
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (sample, target) where target is class_index of the target class.
+         """
+         path, target = self.samples[index]
+         sample = self.loader(path)
+         if self.transform is not None:
+             sample = self.transform(sample)
+         if self.target_transform is not None:
+             target = self.target_transform(target)
+
+         return sample
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __repr__(self):
+         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
+         fmt_str += '    Root Location: {}\n'.format(self.root)
+         tmp = '    Transforms (if any): '
+         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+         tmp = '    Target Transforms (if any): '
+         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+         return fmt_str
+
+
+ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
+
+
+ def pil_loader(path):
+     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+     with open(path, 'rb') as f:
+         img = Image.open(f)
+         return img.convert('RGB')
+
+
+ def default_loader(path):
+     return pil_loader(path)
+
+
+ class ImageFolder(DatasetFolder):
+     def __init__(self, root, transform1=None, transform2=None, target_transform=None,
+                  loader=default_loader):
+         super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
+                                           transform=transform1,
+                                           target_transform=target_transform)
+         self.imgs = self.samples
+         self.transform2 = transform2
+
+     def set_stage(self, stage):
+         if stage == 'last':
+             self.transform = self.transform2
+
+ class ListFolder(Dataset):
+     def __init__(self, txt, transform):
+         with open(txt) as f:
+             imgpaths = f.readlines()
+         self.imgpaths = [x.strip() for x in imgpaths]
+         self.transform = transform
+
+     def __getitem__(self, idx):
+         path = self.imgpaths[idx]
+         image = Image.open(path)
+         return self.transform(image)
+
+     def __len__(self):
+         return len(self.imgpaths)
+
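For context, a minimal usage sketch (not part of the commit): `ImageFolder` here returns only the transformed image tensor, so a `DataLoader` batch is a single tensor rather than an `(image, label)` pair as in torchvision. The directory path below is hypothetical.

```python
from torchvision import transforms
from torch.utils import data
from dataset import ImageFolder

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = ImageFolder('photos/trainA', transform)   # hypothetical image directory
loader = data.DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
batch = next(iter(loader))                          # tensor of shape [4, 3, 256, 256]
```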
distributed.py ADDED
@@ -0,0 +1,126 @@
+ import math
+ import pickle
+
+ import torch
+ from torch import distributed as dist
+ from torch.utils.data.sampler import Sampler
+
+
+ def get_rank():
+     if not dist.is_available():
+         return 0
+
+     if not dist.is_initialized():
+         return 0
+
+     return dist.get_rank()
+
+
+ def synchronize():
+     if not dist.is_available():
+         return
+
+     if not dist.is_initialized():
+         return
+
+     world_size = dist.get_world_size()
+
+     if world_size == 1:
+         return
+
+     dist.barrier()
+
+
+ def get_world_size():
+     if not dist.is_available():
+         return 1
+
+     if not dist.is_initialized():
+         return 1
+
+     return dist.get_world_size()
+
+
+ def reduce_sum(tensor):
+     if not dist.is_available():
+         return tensor
+
+     if not dist.is_initialized():
+         return tensor
+
+     tensor = tensor.clone()
+     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+
+     return tensor
+
+
+ def gather_grad(params):
+     world_size = get_world_size()
+
+     if world_size == 1:
+         return
+
+     for param in params:
+         if param.grad is not None:
+             dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+             param.grad.data.div_(world_size)
+
+
+ def all_gather(data):
+     world_size = get_world_size()
+
+     if world_size == 1:
+         return [data]
+
+     buffer = pickle.dumps(data)
+     storage = torch.ByteStorage.from_buffer(buffer)
+     tensor = torch.ByteTensor(storage).to('cuda')
+
+     local_size = torch.IntTensor([tensor.numel()]).to('cuda')
+     size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
+     dist.all_gather(size_list, local_size)
+     size_list = [int(size.item()) for size in size_list]
+     max_size = max(size_list)
+
+     tensor_list = []
+     for _ in size_list:
+         tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
+
+     if local_size != max_size:
+         padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
+         tensor = torch.cat((tensor, padding), 0)
+
+     dist.all_gather(tensor_list, tensor)
+
+     data_list = []
+
+     for size, tensor in zip(size_list, tensor_list):
+         buffer = tensor.cpu().numpy().tobytes()[:size]
+         data_list.append(pickle.loads(buffer))
+
+     return data_list
+
+
+ def reduce_loss_dict(loss_dict):
+     world_size = get_world_size()
+
+     if world_size < 2:
+         return loss_dict
+
+     with torch.no_grad():
+         keys = []
+         losses = []
+
+         for k in sorted(loss_dict.keys()):
+             keys.append(k)
+             losses.append(loss_dict[k])
+
+         losses = torch.stack(losses, 0)
+         dist.reduce(losses, dst=0)
+
+         if dist.get_rank() == 0:
+             losses /= world_size
+
+         reduced_losses = {k: v for k, v in zip(keys, losses)}
+
+     return reduced_losses
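A small illustration (not part of the commit) of how `reduce_loss_dict` is meant to be called, mirroring its use in `train.py`: on a single process (`world_size < 2`) it returns the dict unchanged, while under an initialized `torch.distributed` group it averages each entry onto rank 0.

```python
import torch
from distributed import reduce_loss_dict, get_rank

# Each value must be a tensor so the dict can be stacked and reduced across ranks.
loss_dict = {'D_adv': torch.tensor(0.7), 'G_adv': torch.tensor(1.3)}
reduced = reduce_loss_dict(loss_dict)

if get_rank() == 0:
    # Only rank 0 holds the averaged values in the distributed case.
    print({k: v.item() for k, v in reduced.items()})
```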
gradiodemo.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.utils import data
+ from torchvision import transforms, utils
+ from tqdm import tqdm
+ torch.backends.cudnn.benchmark = True
+ import copy
+ from util import *
+ from PIL import Image
+
+ from model import *
+ import moviepy.video.io.ImageSequenceClip
+ import scipy
+ import kornia.augmentation as K
+
+ from base64 import b64encode
+ import gradio as gr
+ from torchvision import transforms
+
+ torch.hub.download_url_to_file('https://i.imgur.com/HiOTPNg.png', 'mona.png')
+ torch.hub.download_url_to_file('https://i.imgur.com/Cw8HcTN.png', 'painting.png')
+
+ device = 'cpu'
+ latent_dim = 8
+ n_mlp = 5
+ num_down = 3
+
+ G_A2B = Generator(256, 4, latent_dim, n_mlp, channel_multiplier=1, lr_mlp=.01, n_res=1).to(device).eval()
+
+ ensure_checkpoint_exists('GNR_checkpoint.pt')
+ ckpt = torch.load('GNR_checkpoint.pt', map_location=device)
+
+ G_A2B.load_state_dict(ckpt['G_A2B_ema'])
+
+ # mean latent
+ truncation = 1
+ with torch.no_grad():
+     mean_style = G_A2B.mapping(torch.randn([1000, latent_dim]).to(device)).mean(0, keepdim=True)
+
+
+ test_transform = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
+ ])
+ plt.rcParams['figure.dpi'] = 200
+
+ torch.manual_seed(84986)
+
+ num_styles = 1
+ style = torch.randn([num_styles, latent_dim]).to(device)
+
+
+ def inference(input_im):
+     real_A = test_transform(input_im).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         A2B_content, _ = G_A2B.encode(real_A)
+         fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles, 1, 1, 1), style)
+         std = (0.5, 0.5, 0.5)
+         mean = (0.5, 0.5, 0.5)
+         z = fake_A2B * torch.tensor(std).view(3, 1, 1)
+         z = z + torch.tensor(mean).view(3, 1, 1)
+         tensor_to_pil = transforms.ToPILImage(mode='RGB')(z.squeeze())
+         return tensor_to_pil
+
+ title = "GANsNRoses"
+ description = "demo for GANsNRoses. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.06561'>GANs N' Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)</a> | <a href='https://github.com/mchong6/GANsNRoses'>Github Repo</a></p>"
+
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="pil", label="Input")],
+     gr.outputs.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     examples=[
+         ["mona.png"],
+         ["painting.png"]
+     ]).launch()
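Roughly what `gr.Interface` does per request, shown as a hedged sketch (not part of the commit) that assumes it runs in the `gradiodemo.py` module context; the output filename is hypothetical. The `std`/`mean` arithmetic in `inference` simply inverts the `Normalize(0.5, 0.5)` transform, mapping the generator output from [-1, 1] back to [0, 1] before conversion to a PIL image.

```python
from PIL import Image

# 'mona.png' is downloaded by torch.hub at the top of this script.
im = Image.open('mona.png').convert('RGB')
out = inference(im)          # returns a PIL.Image translated to the anime domain
out.save('mona_anime.png')   # hypothetical output filename
```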
inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
inference_colab.ipynb ADDED
The diff for this file is too large to render. See raw diff
model.py ADDED
@@ -0,0 +1,757 @@
+ import torchvision
+ import math
+ import random
+ import functools
+ import operator
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.autograd import Function
+
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+ n_latent = 11
+
+
+ channels = {
+     4: 512,
+     8: 512,
+     16: 512,
+     32: 512,
+     64: 256,
+     128: 128,
+     256: 64,
+     512: 32,
+     1024: 16,
+ }
+
+ class LambdaLR():
+     def __init__(self, n_epochs, offset, decay_start_epoch):
+         assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
+         self.n_epochs = n_epochs
+         self.offset = offset
+         self.decay_start_epoch = decay_start_epoch
+
+     def step(self, epoch):
+         return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
+
+
+ class PixelNorm(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, input):
+         return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+ def make_kernel(k):
+     k = torch.tensor(k, dtype=torch.float32)
+
+     if k.ndim == 1:
+         k = k[None, :] * k[:, None]
+
+     k /= k.sum()
+
+     return k
+
+ class Upsample(nn.Module):
+     def __init__(self, kernel, factor=2):
+         super().__init__()
+
+         self.factor = factor
+         kernel = make_kernel(kernel) * (factor ** 2)
+         self.register_buffer('kernel', kernel)
+
+         p = kernel.shape[0] - factor
+
+         pad0 = (p + 1) // 2 + factor - 1
+         pad1 = p // 2
+
+         self.pad = (pad0, pad1)
+
+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+         return out
+
+
+ class Downsample(nn.Module):
+     def __init__(self, kernel, factor=2):
+         super().__init__()
+
+         self.factor = factor
+         kernel = make_kernel(kernel)
+         self.register_buffer('kernel', kernel)
+
+         p = kernel.shape[0] - factor
+
+         pad0 = (p + 1) // 2
+         pad1 = p // 2
+
+         self.pad = (pad0, pad1)
+
+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+         return out
+
+
+ class Blur(nn.Module):
+     def __init__(self, kernel, pad, upsample_factor=1):
+         super().__init__()
+
+         kernel = make_kernel(kernel)
+
+         if upsample_factor > 1:
+             kernel = kernel * (upsample_factor ** 2)
+
+         self.register_buffer('kernel', kernel)
+
+         self.pad = pad
+
+     def forward(self, input):
+         out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+         return out
+
+
+ class EqualConv2d(nn.Module):
+     def __init__(
+         self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+     ):
+         super().__init__()
+
+         self.weight = nn.Parameter(
+             torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+         )
+         self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+         self.stride = stride
+         self.padding = padding
+
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_channel))
+
+         else:
+             self.bias = None
+
+     def forward(self, input):
+         out = F.conv2d(
+             input,
+             self.weight * self.scale,
+             bias=self.bias,
+             stride=self.stride,
+             padding=self.padding,
+         )
+
+         return out
+
+     def __repr__(self):
+         return (
+             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+             f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+         )
+
+
+ class EqualLinear(nn.Module):
+     def __init__(
+         self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+     ):
+         super().__init__()
+
+         self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+         else:
+             self.bias = None
+
+         self.activation = activation
+
+         self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+         self.lr_mul = lr_mul
+
+     def forward(self, input):
+         bias = self.bias*self.lr_mul if self.bias is not None else None
+         if self.activation:
+             out = F.linear(input, self.weight * self.scale)
+             out = fused_leaky_relu(out, bias)
+
+         else:
+             out = F.linear(
+                 input, self.weight * self.scale, bias=bias
+             )
+
+         return out
+
+     def __repr__(self):
+         return (
+             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+         )
+
+
+ class ScaledLeakyReLU(nn.Module):
+     def __init__(self, negative_slope=0.2):
+         super().__init__()
+
+         self.negative_slope = negative_slope
+
+     def forward(self, input):
+         out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+         return out * math.sqrt(2)
+
+
+ class ModulatedConv2d(nn.Module):
+     def __init__(
+         self,
+         in_channel,
+         out_channel,
+         kernel_size,
+         style_dim,
+         use_style=True,
+         demodulate=True,
+         upsample=False,
+         downsample=False,
+         blur_kernel=[1, 3, 3, 1],
+     ):
+         super().__init__()
+
+         self.eps = 1e-8
+         self.kernel_size = kernel_size
+         self.in_channel = in_channel
+         self.out_channel = out_channel
+         self.upsample = upsample
+         self.downsample = downsample
+         self.use_style = use_style
+
+         if upsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) - (kernel_size - 1)
+             pad0 = (p + 1) // 2 + factor - 1
+             pad1 = p // 2 + 1
+
+             self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+         if downsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) + (kernel_size - 1)
+             pad0 = (p + 1) // 2
+             pad1 = p // 2
+
+             self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+         fan_in = in_channel * kernel_size ** 2
+         self.scale = 1 / math.sqrt(fan_in)
+         self.padding = kernel_size // 2
+
+         self.weight = nn.Parameter(
+             torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+         )
+
+         if use_style:
+             self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+         else:
+             self.modulation = nn.Parameter(torch.Tensor(1, 1, in_channel, 1, 1).fill_(1))
+
+         self.demodulate = demodulate
+
+     def __repr__(self):
+         return (
+             f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+             f'upsample={self.upsample}, downsample={self.downsample})'
+         )
+
+     def forward(self, input, style):
+         batch, in_channel, height, width = input.shape
+
+         if self.use_style:
+             style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+             weight = self.scale * self.weight * style
+         else:
+             weight = self.scale * self.weight.expand(batch, -1, -1, -1, -1) * self.modulation
+
+         if self.demodulate:
+             demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+             weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+         weight = weight.view(
+             batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+         )
+
+         if self.upsample:
+             input = input.view(1, batch * in_channel, height, width)
+             weight = weight.view(
+                 batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+             )
+             weight = weight.transpose(1, 2).reshape(
+                 batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+             )
+             out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)
+             out = self.blur(out)
+
+         elif self.downsample:
+             input = self.blur(input)
+             _, _, height, width = input.shape
+             input = input.view(1, batch * in_channel, height, width)
+             out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)
+
+         else:
+             input = input.view(1, batch * in_channel, height, width)
+             out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+             _, _, height, width = out.shape
+             out = out.view(batch, self.out_channel, height, width)
+
+         return out
+
+
+ class NoiseInjection(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         self.weight = nn.Parameter(torch.zeros(1))
+
+     def forward(self, image, noise=None):
+         if noise is None:
+             batch, _, height, width = image.shape
+             noise = image.new_empty(batch, 1, height, width).normal_()
+
+         return image + self.weight * noise
+
+
+ class ConstantInput(nn.Module):
+     def __init__(self, style_dim):
+         super().__init__()
+
+         self.input = nn.Parameter(torch.randn(1, style_dim))
+
+     def forward(self, input):
+         batch = input.shape[0]
+         out = self.input.repeat(batch, n_latent)
+
+         return out
+
+
+ class StyledConv(nn.Module):
+     def __init__(
+         self,
+         in_channel,
+         out_channel,
+         kernel_size,
+         style_dim,
+         use_style=True,
+         upsample=False,
+         downsample=False,
+         blur_kernel=[1, 3, 3, 1],
+         demodulate=True,
+     ):
+         super().__init__()
+         self.use_style = use_style
+
+         self.conv = ModulatedConv2d(
+             in_channel,
+             out_channel,
+             kernel_size,
+             style_dim,
+             use_style=use_style,
+             upsample=upsample,
+             downsample=downsample,
+             blur_kernel=blur_kernel,
+             demodulate=demodulate,
+         )
+
+         #if use_style:
+         #    self.noise = NoiseInjection()
+         #else:
+         #    self.noise = None
+         #    self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+         # self.activate = ScaledLeakyReLU(0.2)
+         self.activate = FusedLeakyReLU(out_channel)
+
+     def forward(self, input, style=None, noise=None):
+         out = self.conv(input, style)
+         #if self.use_style:
+         #    out = self.noise(out, noise=noise)
+         # out = out + self.bias
+         out = self.activate(out)
+
+         return out
+
+
+ class StyledResBlock(nn.Module):
+     def __init__(self, in_channel, style_dim, blur_kernel=[1, 3, 3, 1], demodulate=True):
+         super().__init__()
+
+         self.conv1 = StyledConv(in_channel, in_channel, 3, style_dim, upsample=False, blur_kernel=blur_kernel, demodulate=demodulate)
+         self.conv2 = StyledConv(in_channel, in_channel, 3, style_dim, upsample=False, blur_kernel=blur_kernel, demodulate=demodulate)
+
+     def forward(self, input, style):
+         out = self.conv1(input, style)
+         out = self.conv2(out, style)
+         out = (out + input) / math.sqrt(2)
+
+         return out
+
+ class ToRGB(nn.Module):
+     def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+
+         if upsample:
+             self.upsample = Upsample(blur_kernel)
+
+         self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+         self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+     def forward(self, input, style, skip=None):
+         out = self.conv(input, style)
+         out = out + self.bias
+
+         if skip is not None:
+             skip = self.upsample(skip)
+
+             out = out + skip
+
+         return out
+
+
+ class Generator(nn.Module):
+     def __init__(
+         self,
+         size,
+         num_down,
+         latent_dim,
+         n_mlp,
+         n_res,
+         channel_multiplier=1,
+         blur_kernel=[1, 3, 3, 1],
+         lr_mlp=0.01,
+     ):
+         super().__init__()
+         self.size = size
+
+         style_dim = 512
+
+         mapping = [EqualLinear(latent_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu')]
+         for i in range(n_mlp-1):
+             mapping.append(EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'))
+
+         self.mapping = nn.Sequential(*mapping)
+
+         self.encoder = Encoder(size, latent_dim, num_down, n_res, channel_multiplier)
+
+         self.log_size = int(math.log(size, 2)) #7
+         in_log_size = self.log_size - num_down #7-2 or 7-3
+         in_size = 2 ** in_log_size
+
+         in_channel = channels[in_size]
+         self.adain_bottleneck = nn.ModuleList()
+         for i in range(n_res):
+             self.adain_bottleneck.append(StyledResBlock(in_channel, style_dim))
+
+         self.conv1 = StyledConv(in_channel, in_channel, 3, style_dim, blur_kernel=blur_kernel)
+         self.to_rgb1 = ToRGB(in_channel, style_dim, upsample=False)
+
+         self.num_layers = (self.log_size - in_log_size) * 2 + 1 #7
+
+         self.convs = nn.ModuleList()
+         self.upsamples = nn.ModuleList()
+         self.to_rgbs = nn.ModuleList()
+         #self.noises = nn.Module()
+
+
+         #for layer_idx in range(self.num_layers):
+         #    res = (layer_idx + (in_log_size*2+1)) // 2 #2,3,3,5 ... -> 4,5,5,6 ...
+         #    shape = [1, 1, 2 ** res, 2 ** res]
+         #    self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+         for i in range(in_log_size+1, self.log_size + 1):
+             out_channel = channels[2 ** i]
+
+             self.convs.append(
+                 StyledConv(
+                     in_channel,
+                     out_channel,
+                     3,
+                     style_dim,
+                     upsample=True,
+                     blur_kernel=blur_kernel,
+                 )
+             )
+
+             self.convs.append(
+                 StyledConv(
+                     out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+                 )
+             )
+
+             self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+             in_channel = out_channel
+
+     def style_encode(self, input):
+         return self.encoder(input)[1]
+
+     def encode(self, input):
+         return self.encoder(input)
+
+     def forward(self, input, z=None):
+         content, style = self.encode(input)
+         if z is None:
+             out = self.decode(content, style)
+         else:
+             out = self.decode(content, z)
+
+         return out, content, style
+
+     def decode(self, input, styles, use_mapping=True):
+         if use_mapping:
+             styles = self.mapping(styles)
+         #styles = styles.repeat(1, n_latent).view(styles.size(0), n_latent, -1)
+         out = input
+         i = 0
+         for conv in self.adain_bottleneck:
+             out = conv(out, styles)
+             i += 1
+
+         out = self.conv1(out, styles, noise=None)
+         skip = self.to_rgb1(out, styles)
+         i += 2
+
+         for conv1, conv2, to_rgb in zip(
+             self.convs[::2], self.convs[1::2], self.to_rgbs
+         ):
+             out = conv1(out, styles, noise=None)
+             out = conv2(out, styles, noise=None)
+             skip = to_rgb(out, styles, skip)
+
+             i += 3
+
+         image = skip
+         return image
+
+ class ConvLayer(nn.Sequential):
+     def __init__(
+         self,
+         in_channel,
+         out_channel,
+         kernel_size,
+         downsample=False,
+         blur_kernel=[1, 3, 3, 1],
+         bias=True,
+         activate=True,
+     ):
+         layers = []
+
+         if downsample:
+             factor = 2
+             p = (len(blur_kernel) - factor) + (kernel_size - 1)
+             pad0 = (p + 1) // 2
+             pad1 = p // 2
+
+             layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+             stride = 2
+             self.padding = 0
+
+         else:
+             stride = 1
+             self.padding = kernel_size // 2
+
+         layers.append(
+             EqualConv2d(
+                 in_channel,
+                 out_channel,
+                 kernel_size,
+                 padding=self.padding,
+                 stride=stride,
+                 bias=bias and not activate,
+             )
+         )
+
+         if activate:
+             if bias:
+                 layers.append(FusedLeakyReLU(out_channel))
+
+             else:
+                 layers.append(ScaledLeakyReLU(0.2))
+
+         super().__init__(*layers)
+
+ class InResBlock(nn.Module):
+     def __init__(self, in_channel, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+
+         self.conv1 = StyledConv(in_channel, in_channel, 3, None, blur_kernel=blur_kernel, demodulate=True, use_style=False)
+         self.conv2 = StyledConv(in_channel, in_channel, 3, None, blur_kernel=blur_kernel, demodulate=True, use_style=False)
+
+     def forward(self, input):
+         out = self.conv1(input, None)
+         out = self.conv2(out, None)
+         out = (out + input) / math.sqrt(2)
+
+         return out
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True):
+         super().__init__()
+
+         self.conv1 = ConvLayer(in_channel, in_channel, 3)
+         self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample)
+
+         if downsample or in_channel != out_channel:
+             self.skip = ConvLayer(
+                 in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
+             )
+         else:
+             self.skip = None
+
+     def forward(self, input):
+         out = self.conv1(input)
+         out = self.conv2(out)
+
+         if self.skip is None:
+             skip = input
+         else:
+             skip = self.skip(input)
+         out = (out + skip) / math.sqrt(2)
+
+         return out
+
+ class Discriminator(nn.Module):
+     def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+         self.size = size
+         l_branch = self.make_net_(32)
+         l_branch += [ConvLayer(channels[32], 1, 1, activate=False)]
+         self.l_branch = nn.Sequential(*l_branch)
+
+
+         g_branch = self.make_net_(8)
+         self.g_branch = nn.Sequential(*g_branch)
+         self.g_adv = ConvLayer(channels[8], 1, 1, activate=False)
+
+         self.g_std = nn.Sequential(ConvLayer(channels[8], channels[4], 3, downsample=True),
+                                    nn.Flatten(),
+                                    EqualLinear(channels[4] * 4 * 4, 128, activation='fused_lrelu'),
+                                    )
+         self.g_final = EqualLinear(128, 1, activation=False)
+
+
+     def make_net_(self, out_size):
+         size = self.size
+         convs = [ConvLayer(3, channels[size], 1)]
+         log_size = int(math.log(size, 2))
+         out_log_size = int(math.log(out_size, 2))
+         in_channel = channels[size]
+
+         for i in range(log_size, out_log_size, -1):
+             out_channel = channels[2 ** (i - 1)]
+             convs.append(ResBlock(in_channel, out_channel))
+             in_channel = out_channel
+
+         return convs
+
+     def forward(self, x):
+         l_adv = self.l_branch(x)
+
+         g_act = self.g_branch(x)
+         g_adv = self.g_adv(g_act)
+
+         output = self.g_std(g_act)
+         g_stddev = torch.sqrt(output.var(0, keepdim=True, unbiased=False) + 1e-8).repeat(x.size(0), 1)
+         g_std = self.g_final(g_stddev)
+         return [l_adv, g_adv, g_std]
+
+
+
+ class Encoder(nn.Module):
+     def __init__(self, size, latent_dim, num_down, n_res, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+         stem = [ConvLayer(3, channels[size], 1)]
+         log_size = int(math.log(size, 2))
+         in_channel = channels[size]
+
+         for i in range(log_size, log_size-num_down, -1):
+             out_channel = channels[2 ** (i - 1)]
+             stem.append(ResBlock(in_channel, out_channel, downsample=True))
+             in_channel = out_channel
+         stem += [ResBlock(in_channel, in_channel, downsample=False) for i in range(n_res)]
+         self.stem = nn.Sequential(*stem)
+
+         self.content = nn.Sequential(
+             ConvLayer(in_channel, in_channel, 1),
+             ConvLayer(in_channel, in_channel, 1)
+         )
+         style = []
+         for i in range(log_size-num_down, 2, -1):
+             out_channel = channels[2 ** (i - 1)]
+             style.append(ConvLayer(in_channel, out_channel, 3, downsample=True))
+             in_channel = out_channel
+         style += [
+             nn.Flatten(),
+             EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+             EqualLinear(channels[4], latent_dim),
+         ]
+         self.style = nn.Sequential(*style)
+
+
+     def forward(self, input):
+         act = self.stem(input)
+         content = self.content(act)
+         style = self.style(act)
+         return content, style
+
+ class StyleEncoder(nn.Module):
+     def __init__(self, size, style_dim, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+         super().__init__()
+         convs = [ConvLayer(3, channels[size], 1)]
+
+         log_size = int(math.log(size, 2))
+
+         in_channel = channels[size]
+         num_down = 6
+
+         for i in range(log_size, log_size-num_down, -1):
+             w = 2 ** (i - 1)
+             out_channel = channels[w]
+             convs.append(ConvLayer(in_channel, out_channel, 3, downsample=True))
+             in_channel = out_channel
+
+         convs += [
+             nn.Flatten(),
+             EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), EqualLinear(channels[4], style_dim),
+         ]
+         self.convs = nn.Sequential(*convs)
+
+     def forward(self, input):
+         style = self.convs(input)
+         return style.view(input.size(0), -1)
+
+ class LatDiscriminator(nn.Module):
+     def __init__(self, style_dim):
+         super().__init__()
+
+         fc = [EqualLinear(style_dim, 256, activation='fused_lrelu')]
+         for i in range(3):
+             fc += [EqualLinear(256, 256, activation='fused_lrelu')]
+         fc += [FCMinibatchStd(256, 256)]
+         fc += [EqualLinear(256, 1)]
+         self.fc = nn.Sequential(*fc)
+
+     def forward(self, input):
+         return [self.fc(input), ]
+
+ class FCMinibatchStd(nn.Module):
+     def __init__(self, in_channel, out_channel):
+         super().__init__()
+         self.fc = EqualLinear(in_channel+1, out_channel, activation='fused_lrelu')
+
+     def forward(self, out):
+         stddev = torch.sqrt(out.var(0, unbiased=False) + 1e-8).mean().view(1, 1).repeat(out.size(0), 1)
+         out = torch.cat([out, stddev], 1)
+         out = self.fc(out)
+         return out
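A sketch (not part of the commit) of the `encode`/`decode` API that `train.py` and `gradiodemo.py` rely on. It assumes the `op/` package (`FusedLeakyReLU`, `fused_leaky_relu`, `upfirdn2d`) from the upstream GANsNRoses/StyleGAN2 code is importable, which typically requires its CUDA/ninja extension to build.

```python
import torch
from model import Generator

# Arguments mirror the train.py defaults: 256px images, 3 downsampling stages,
# an 8-dim style code, a 5-layer mapping network, and one bottleneck res-block.
G = Generator(256, num_down=3, latent_dim=8, n_mlp=5, n_res=1, channel_multiplier=1)

x = torch.randn(1, 3, 256, 256)
content, style = G.encode(x)   # content: spatial code; style: [1, 8] vector
z = torch.randn(1, 8)          # a random style in place of the encoded one
fake = G.decode(content, z)    # [1, 3, 256, 256] translated image
```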
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ tqdm
+ gdown
+ kornia
+ scipy
+ opencv-python
+ moviepy
+ lpips
+ ninja
+ gradio
+ torchvision
teaser.gif ADDED
teaser.png ADDED
train.py ADDED
@@ -0,0 +1,458 @@
+ import argparse
+ import math
+ import random
+ import os
+ from util import *
+ import numpy as np
+ import torch
+ torch.backends.cudnn.benchmark = True
+ from torch import nn, autograd
+ from torch import optim
+ from torch.nn import functional as F
+ from torch.utils import data
+ import torch.distributed as dist
+
+ from torchvision import transforms, utils
+ from tqdm import tqdm
+ from torch.optim import lr_scheduler
+ import copy
+ import kornia.augmentation as K
+ import kornia
+ import lpips
+
+ from model import *
+ from dataset import ImageFolder
+ from distributed import (
+     get_rank,
+     synchronize,
+     reduce_loss_dict,
+     reduce_sum,
+     get_world_size,
+ )
+
+ mse_criterion = nn.MSELoss()
+
+
+ def test(args, genA2B, genB2A, testA_loader, testB_loader, name, step):
+     testA_loader = iter(testA_loader)
+     testB_loader = iter(testB_loader)
+     with torch.no_grad():
+         test_sample_num = 16
+
+         genA2B.eval(), genB2A.eval()
+         A2B = []
+         B2A = []
+         for i in range(test_sample_num):
+             real_A = testA_loader.next()
+             real_B = testB_loader.next()
+
+             real_A, real_B = real_A.cuda(), real_B.cuda()
+
+             A2B_content, A2B_style = genA2B.encode(real_A)
+             B2A_content, B2A_style = genB2A.encode(real_B)
+
+             if i % 2 == 0:
+                 A2B_mod1 = torch.randn([1, args.latent_dim]).cuda()
+                 B2A_mod1 = torch.randn([1, args.latent_dim]).cuda()
+                 A2B_mod2 = torch.randn([1, args.latent_dim]).cuda()
+                 B2A_mod2 = torch.randn([1, args.latent_dim]).cuda()
+
+             fake_B2B, _, _ = genA2B(real_B)
+             fake_A2A, _, _ = genB2A(real_A)
+
+             colsA = [real_A, fake_A2A]
+             colsB = [real_B, fake_B2B]
+
+             fake_A2B_1 = genA2B.decode(A2B_content, A2B_mod1)
+             fake_B2A_1 = genB2A.decode(B2A_content, B2A_mod1)
+
+             fake_A2B_2 = genA2B.decode(A2B_content, A2B_mod2)
+             fake_B2A_2 = genB2A.decode(B2A_content, B2A_mod2)
+
+             fake_A2B_3 = genA2B.decode(A2B_content, B2A_style)
+             fake_B2A_3 = genB2A.decode(B2A_content, A2B_style)
+
+             colsA += [fake_A2B_3, fake_A2B_1, fake_A2B_2]
+             colsB += [fake_B2A_3, fake_B2A_1, fake_B2A_2]
+
+             fake_A2B2A, _, _ = genB2A(fake_A2B_3, A2B_style)
+             fake_B2A2B, _, _ = genA2B(fake_B2A_3, B2A_style)
+             colsA.append(fake_A2B2A)
+             colsB.append(fake_B2A2B)
+
+             fake_A2B2A, _, _ = genB2A(fake_A2B_1, A2B_style)
+             fake_B2A2B, _, _ = genA2B(fake_B2A_1, B2A_style)
+             colsA.append(fake_A2B2A)
+             colsB.append(fake_B2A2B)
+
+             fake_A2B2A, _, _ = genB2A(fake_A2B_2, A2B_style)
+             fake_B2A2B, _, _ = genA2B(fake_B2A_2, B2A_style)
+             colsA.append(fake_A2B2A)
+             colsB.append(fake_B2A2B)
+
+             fake_A2B2A, _, _ = genB2A(fake_A2B_1)
+             fake_B2A2B, _, _ = genA2B(fake_B2A_1)
+             colsA.append(fake_A2B2A)
+             colsB.append(fake_B2A2B)
+
+             colsA = torch.cat(colsA, 2).detach().cpu()
+             colsB = torch.cat(colsB, 2).detach().cpu()
+
+             A2B.append(colsA)
+             B2A.append(colsB)
+         A2B = torch.cat(A2B, 0)
+         B2A = torch.cat(B2A, 0)
+
+         utils.save_image(A2B, f'{im_path}/{name}_A2B_{str(step).zfill(6)}.jpg', normalize=True, range=(-1, 1), nrow=16)
+         utils.save_image(B2A, f'{im_path}/{name}_B2A_{str(step).zfill(6)}.jpg', normalize=True, range=(-1, 1), nrow=16)
+
+         genA2B.train(), genB2A.train()
+
+
+ def train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device):
+     G_A2B.train(), G_B2A.train(), D_A.train(), D_B.train()
+     trainA_loader = sample_data(trainA_loader)
+     trainB_loader = sample_data(trainB_loader)
+     G_scheduler = lr_scheduler.StepLR(G_optim, step_size=100000, gamma=0.5)
+     D_scheduler = lr_scheduler.StepLR(D_optim, step_size=100000, gamma=0.5)
+
+     pbar = range(args.iter)
+
+     if get_rank() == 0:
+         pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.1)
+
+     loss_dict = {}
+     mean_path_length_A2B = 0
+     mean_path_length_B2A = 0
+
+     if args.distributed:
+         G_A2B_module = G_A2B.module
+         G_B2A_module = G_B2A.module
+         D_A_module = D_A.module
+         D_B_module = D_B.module
+         D_L_module = D_L.module
+
+     else:
+         G_A2B_module = G_A2B
+         G_B2A_module = G_B2A
+         D_A_module = D_A
+         D_B_module = D_B
+         D_L_module = D_L
+
+     for idx in pbar:
+         i = idx + args.start_iter
+
+         if i > args.iter:
+             print('Done!')
+             break
+
+         ori_A = next(trainA_loader)
+         ori_B = next(trainB_loader)
+         if isinstance(ori_A, list):
+             ori_A = ori_A[0]
+         if isinstance(ori_B, list):
+             ori_B = ori_B[0]
+
+         ori_A = ori_A.to(device)
+         ori_B = ori_B.to(device)
+         aug_A = aug(ori_A)
+         aug_B = aug(ori_B)
+         A = aug(ori_A[[np.random.randint(args.batch)]].expand_as(ori_A))
+         B = aug(ori_B[[np.random.randint(args.batch)]].expand_as(ori_B))
+
+         if i % args.d_reg_every == 0:
+             aug_A.requires_grad = True
+             aug_B.requires_grad = True
+
+         A2B_content, A2B_style = G_A2B.encode(A)
+         B2A_content, B2A_style = G_B2A.encode(B)
+
+         # get new style
+         aug_A2B_style = G_B2A.style_encode(aug_B)
+         aug_B2A_style = G_A2B.style_encode(aug_A)
+         rand_A2B_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_()
+         rand_B2A_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_()
+
+         # styles
+         idx = torch.randperm(2*args.batch)
+         input_A2B_style = torch.cat([rand_A2B_style, aug_A2B_style], 0)[idx][:args.batch]
+
+         idx = torch.randperm(2*args.batch)
+         input_B2A_style = torch.cat([rand_B2A_style, aug_B2A_style], 0)[idx][:args.batch]
+
+         fake_A2B = G_A2B.decode(A2B_content, input_A2B_style)
+         fake_B2A = G_B2A.decode(B2A_content, input_B2A_style)
+
+
+         # train disc
+         real_A_logit = D_A(aug_A)
+         real_B_logit = D_B(aug_B)
+         real_L_logit1 = D_L(rand_A2B_style)
+         real_L_logit2 = D_L(rand_B2A_style)
+
+         fake_B_logit = D_B(fake_A2B.detach())
+         fake_A_logit = D_A(fake_B2A.detach())
+         fake_L_logit1 = D_L(aug_A2B_style.detach())
+         fake_L_logit2 = D_L(aug_B2A_style.detach())
+
+         # global loss
+         D_loss = d_logistic_loss(real_A_logit, fake_A_logit) +\
+                  d_logistic_loss(real_B_logit, fake_B_logit) +\
+                  d_logistic_loss(real_L_logit1, fake_L_logit1) +\
+                  d_logistic_loss(real_L_logit2, fake_L_logit2)
+
+         loss_dict['D_adv'] = D_loss
+
+         if i % args.d_reg_every == 0:
+             r1_A_loss = d_r1_loss(real_A_logit, aug_A)
+             r1_B_loss = d_r1_loss(real_B_logit, aug_B)
+             r1_L_loss = d_r1_loss(real_L_logit1, rand_A2B_style) + d_r1_loss(real_L_logit2, rand_B2A_style)
+             r1_loss = r1_A_loss + r1_B_loss + r1_L_loss
+             D_r1_loss = (args.r1 / 2 * r1_loss * args.d_reg_every)
+             D_loss += D_r1_loss
+
+         D_optim.zero_grad()
+         D_loss.backward()
+         D_optim.step()
+
+         #Generator
+         # adv loss
+         fake_B_logit = D_B(fake_A2B)
+         fake_A_logit = D_A(fake_B2A)
+         fake_L_logit1 = D_L(aug_A2B_style)
+         fake_L_logit2 = D_L(aug_B2A_style)
+
+         lambda_adv = (1, 1, 1)
+         G_adv_loss = 1 * (g_nonsaturating_loss(fake_A_logit, lambda_adv) +\
+                           g_nonsaturating_loss(fake_B_logit, lambda_adv) +\
+                           2*g_nonsaturating_loss(fake_L_logit1, (1,)) +\
+                           2*g_nonsaturating_loss(fake_L_logit2, (1,)))
+
+         # style consis loss
+         G_con_loss = 50 * (A2B_style.var(0, unbiased=False).sum() + B2A_style.var(0, unbiased=False).sum())
+
+         # cycle recon
+         A2B2A_content, A2B2A_style = G_B2A.encode(fake_A2B)
+         B2A2B_content, B2A2B_style = G_A2B.encode(fake_B2A)
+         fake_A2B2A = G_B2A.decode(A2B2A_content, shuffle_batch(A2B_style))
+         fake_B2A2B = G_A2B.decode(B2A2B_content, shuffle_batch(B2A_style))
+
+         G_cycle_loss = 20 * (F.mse_loss(fake_A2B2A, A) + F.mse_loss(fake_B2A2B, B))
+         lpips_loss = 10 * (lpips_fn(fake_A2B2A, A).mean() + lpips_fn(fake_B2A2B, B).mean()) #10 for anime
+
+         # style reconstruction
+         G_style_loss = 5 * (mse_criterion(A2B2A_style, input_A2B_style) +\
+                             mse_criterion(B2A2B_style, input_B2A_style))
+
+
+         G_loss = G_adv_loss + G_cycle_loss + G_con_loss + lpips_loss + G_style_loss
+
+         loss_dict['G_adv'] = G_adv_loss
+         loss_dict['G_con'] = G_con_loss
+         loss_dict['G_cycle'] = G_cycle_loss
+         loss_dict['lpips'] = lpips_loss
+
+         G_optim.zero_grad()
+         G_loss.backward()
+         G_optim.step()
+
+         G_scheduler.step()
+         D_scheduler.step()
+
+         accumulate(G_A2B_ema, G_A2B_module)
+         accumulate(G_B2A_ema, G_B2A_module)
+
+         loss_reduced = reduce_loss_dict(loss_dict)
+         D_adv_loss_val = loss_reduced['D_adv'].mean().item()
+
+         G_adv_loss_val = loss_reduced['G_adv'].mean().item()
+         G_cycle_loss_val = loss_reduced['G_cycle'].mean().item()
+         G_con_loss_val = loss_reduced['G_con'].mean().item()
+         lpips_val = loss_reduced['lpips'].mean().item()
+
+         if get_rank() == 0:
+             pbar.set_description(
+                 (
+                     f'Dadv: {D_adv_loss_val:.2f}; lpips: {lpips_val:.2f} '
+                     f'Gadv: {G_adv_loss_val:.2f}; Gcycle: {G_cycle_loss_val:.2f}; GMS: {G_con_loss_val:.2f} {G_style_loss.item():.2f}'
+                 )
+             )
+
+             if i % 1000 == 0:
+                 with torch.no_grad():
+                     test(args, G_A2B, G_B2A, testA_loader, testB_loader, 'normal', i)
+                     test(args, G_A2B_ema, G_B2A_ema, testA_loader, testB_loader, 'ema', i)
+
+             if (i+1) % 2000 == 0:
+                 torch.save(
+                     {
+                         'G_A2B': G_A2B_module.state_dict(),
+                         'G_B2A': G_B2A_module.state_dict(),
+                         'G_A2B_ema': G_A2B_ema.state_dict(),
+                         'G_B2A_ema': G_B2A_ema.state_dict(),
+                         'D_A': D_A_module.state_dict(),
+                         'D_B': D_B_module.state_dict(),
+                         'D_L': D_L_module.state_dict(),
+                         'G_optim': G_optim.state_dict(),
+                         'D_optim': D_optim.state_dict(),
+                         'iter': i,
+                     },
+                     os.path.join(model_path, 'ck.pt'),
+                 )
+
+
+ if __name__ == '__main__':
+     device = 'cuda'
+
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--iter', type=int, default=300000)
+     parser.add_argument('--batch', type=int, default=4)
+     parser.add_argument('--n_sample', type=int, default=64)
+     parser.add_argument('--size', type=int, default=256)
+     parser.add_argument('--r1', type=float, default=10)
+     parser.add_argument('--lambda_cycle', type=int, default=1)
+     parser.add_argument('--path_regularize', type=float, default=2)
+     parser.add_argument('--path_batch_shrink', type=int, default=2)
+     parser.add_argument('--d_reg_every', type=int, default=16)
+     parser.add_argument('--g_reg_every', type=int, default=4)
+     parser.add_argument('--mixing', type=float, default=0.9)
+     parser.add_argument('--ckpt', type=str, default=None)
+     parser.add_argument('--lr', type=float, default=2e-3)
+     parser.add_argument('--local_rank', type=int, default=0)
+     parser.add_argument('--num_down', type=int, default=3)
+     parser.add_argument('--name', type=str, required=True)
+     parser.add_argument('--d_path', type=str, required=True)
+     parser.add_argument('--latent_dim', type=int, default=8)
+     parser.add_argument('--lr_mlp', type=float, default=0.01)
+     parser.add_argument('--n_res', type=int, default=1)
+
+     args = parser.parse_args()
+
+     n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+     args.distributed = False
+
+     if args.distributed:
+         torch.cuda.set_device(args.local_rank)
+         torch.distributed.init_process_group(backend='nccl', init_method='env://')
+         synchronize()
+
+     save_path = f'./{args.name}'
+     im_path = os.path.join(save_path, 'sample')
+     model_path = os.path.join(save_path, 'checkpoint')
+     os.makedirs(im_path, exist_ok=True)
+     os.makedirs(model_path, exist_ok=True)
+
+     args.n_mlp = 5
+
+     args.start_iter = 0
+
+     G_A2B = Generator( args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device)
+     D_A = Discriminator(args.size).to(device)
+     G_B2A = Generator( args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device)
+     D_B = Discriminator(args.size).to(device)
+     D_L = LatDiscriminator(args.latent_dim).to(device)
+     lpips_fn = lpips.LPIPS(net='vgg').to(device)
+
+     G_A2B_ema = copy.deepcopy(G_A2B).to(device).eval()
+     G_B2A_ema = copy.deepcopy(G_B2A).to(device).eval()
+
+     g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
+     d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
+
+     G_optim = optim.Adam( list(G_A2B.parameters()) + list(G_B2A.parameters()), lr=args.lr, betas=(0, 0.99))
+     D_optim = optim.Adam(
+         list(D_L.parameters()) + list(D_A.parameters()) + list(D_B.parameters()),
+         lr=args.lr, betas=(0**d_reg_ratio, 0.99**d_reg_ratio))
+
+     if args.ckpt is not None:
+         ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
+
+         try:
+             ckpt_name = os.path.basename(args.ckpt)
+             args.start_iter = int(os.path.splitext(ckpt_name)[0])
+
+         except ValueError:
+             pass
+
+         G_A2B.load_state_dict(ckpt['G_A2B'])
+         G_B2A.load_state_dict(ckpt['G_B2A'])
+         G_A2B_ema.load_state_dict(ckpt['G_A2B_ema'])
+         G_B2A_ema.load_state_dict(ckpt['G_B2A_ema'])
+         D_A.load_state_dict(ckpt['D_A'])
+         D_B.load_state_dict(ckpt['D_B'])
+         D_L.load_state_dict(ckpt['D_L'])
+
+         G_optim.load_state_dict(ckpt['G_optim'])
+         D_optim.load_state_dict(ckpt['D_optim'])
+         args.start_iter = ckpt['iter']
+
+     if args.distributed:
+         G_A2B = nn.parallel.DistributedDataParallel(
+             G_A2B,
+             device_ids=[args.local_rank],
+             output_device=args.local_rank,
+             broadcast_buffers=False,
+         )
+
+         D_A = nn.parallel.DistributedDataParallel(
+             D_A,
+             device_ids=[args.local_rank],
+             output_device=args.local_rank,
+             broadcast_buffers=False,
+         )
+
+         G_B2A = nn.parallel.DistributedDataParallel(
+             G_B2A,
+             device_ids=[args.local_rank],
+             output_device=args.local_rank,
+             broadcast_buffers=False,
+         )
+
+         D_B = nn.parallel.DistributedDataParallel(
+             D_B,
+             device_ids=[args.local_rank],
+             output_device=args.local_rank,
+             broadcast_buffers=False,
+         )
+         D_L = nn.parallel.DistributedDataParallel(
+             D_L,
+             device_ids=[args.local_rank],
+             output_device=args.local_rank,
+             broadcast_buffers=False,
+         )
+     train_transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
+     ])
+
+     test_transform = transforms.Compose([
+         transforms.Resize((args.size, args.size)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
+     ])
+
+     aug = nn.Sequential(
+         K.RandomAffine(degrees=(-20, 20), scale=(0.8, 1.2), translate=(0.1, 0.1), shear=0.15),
+         kornia.geometry.transform.Resize(256+30),
+         K.RandomCrop((256, 256)),
+         K.RandomHorizontalFlip(),
+     )
+
+
+     d_path = args.d_path
+     trainA = ImageFolder(os.path.join(d_path, 'trainA'), train_transform)
+     trainB = ImageFolder(os.path.join(d_path, 'trainB'), train_transform)
+     testA = ImageFolder(os.path.join(d_path, 'testA'), test_transform)
+     testB = ImageFolder(os.path.join(d_path, 'testB'), test_transform)
+
+     trainA_loader = data.DataLoader(trainA, batch_size=args.batch,
+                                     sampler=data_sampler(trainA, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5)
+     trainB_loader = data.DataLoader(trainB, batch_size=args.batch,
+                                     sampler=data_sampler(trainB, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5)
+
+     testA_loader = data.DataLoader(testA, batch_size=1, shuffle=False)
+     testB_loader = data.DataLoader(testB, batch_size=1, shuffle=False)
+
+
+     train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device)
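An illustrative sketch (not part of the commit) of how the multi-branch discriminator output plugs into the loss helpers from `util.py`: `Discriminator.forward` returns a list `[l_adv, g_adv, g_std]`, and `d_logistic_loss` / `g_nonsaturating_loss` zip over such lists, which is why `train.py` passes `lambda_adv = (1, 1, 1)`. The random tensors below stand in for real discriminator logits.

```python
import torch
from util import d_logistic_loss, g_nonsaturating_loss

# Stand-ins for D_A(real) and D_A(fake): one logit tensor per discriminator head.
real_logits = [torch.randn(4, 1), torch.randn(4, 1), torch.randn(4, 1)]
fake_logits = [torch.randn(4, 1), torch.randn(4, 1), torch.randn(4, 1)]

d_loss = d_logistic_loss(real_logits, fake_logits)            # softplus loss summed over heads
g_loss = g_nonsaturating_loss(fake_logits, weights=(1, 1, 1)) # per-head weights, averaged
```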
util.py ADDED
@@ -0,0 +1,161 @@
+ import torch
+ import torch.nn.functional as F
+ from torch.utils import data
+ from torch import nn, autograd
+ import os
+ import matplotlib.pyplot as plt
+
+
+ google_drive_paths = {
+     "GNR_checkpoint.pt": "https://drive.google.com/uc?id=1IMIVke4WDaGayUa7vk_xVw1uqIHikGtC",
+ }
+
+ def ensure_checkpoint_exists(model_weights_filename):
+     if not os.path.isfile(model_weights_filename) and (
+         model_weights_filename in google_drive_paths
+     ):
+         gdrive_url = google_drive_paths[model_weights_filename]
+         try:
+             from gdown import download as drive_download
+
+             drive_download(gdrive_url, model_weights_filename, quiet=False)
+         except ModuleNotFoundError:
+             print(
+                 "gdown module not found.",
+                 "pip3 install gdown or, manually download the checkpoint file:",
+                 gdrive_url
+             )
+
+     if not os.path.isfile(model_weights_filename) and (
+         model_weights_filename not in google_drive_paths
+     ):
+         print(
+             model_weights_filename,
+             " not found, you may need to manually download the model weights."
+         )
+
+ def shuffle_batch(x):
+     return x[torch.randperm(x.size(0))]
+
+ def data_sampler(dataset, shuffle, distributed):
+     if distributed:
+         return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
+
+     if shuffle:
+         return data.RandomSampler(dataset)
+
+     else:
+         return data.SequentialSampler(dataset)
+
+
+ def accumulate(model1, model2, decay=0.999):
+     par1 = dict(model1.named_parameters())
+     par2 = dict(model2.named_parameters())
+
+     for k in par1.keys():
+         par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)
+
+
+ def sample_data(loader):
+     while True:
+         for batch in loader:
+             yield batch
+
+
+ def d_logistic_loss(real_pred, fake_pred):
+     loss = 0
+     for real, fake in zip(real_pred, fake_pred):
+         real_loss = F.softplus(-real)
+         fake_loss = F.softplus(fake)
+         loss += real_loss.mean() + fake_loss.mean()
+
+     return loss
+
+
+ def d_r1_loss(real_pred, real_img):
+     grad_penalty = 0
+     for real in real_pred:
+         grad_real, = autograd.grad(
+             outputs=real.mean(), inputs=real_img, create_graph=True, only_inputs=True
+         )
+         grad_penalty += grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+
+     return grad_penalty
+
+
+ def g_nonsaturating_loss(fake_pred, weights):
+     loss = 0
+     for fake, weight in zip(fake_pred, weights):
+         loss += weight*F.softplus(-fake).mean()
+
+     return loss / len(fake_pred)
+
+ def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
+     # image is [3,h,w] or [1,3,h,w] tensor [0,1]
+     if image.is_cuda:
+         image = image.cpu()
+     if size is not None and image.size(-1) != size:
+         image = F.interpolate(image, size=(size, size), mode=mode)
+     if image.dim() == 4:
+         image = image[0]
+     image = image.permute(1, 2, 0).detach().numpy()
+     plt.figure()
+     plt.title(title)
+     plt.axis('off')
+     plt.imshow(image)
+
+ def normalize(x):
+     return ((x+1)/2).clamp(0, 1)
+
+ def get_boundingbox(face, width, height, scale=1.3, minsize=None):
+     """
+     Expects a dlib face to generate a quadratic bounding box.
+     :param face: dlib face class
+     :param width: frame width
+     :param height: frame height
+     :param scale: bounding box size multiplier to get a bigger face region
+     :param minsize: set minimum bounding box size
+     :return: x, y, bounding_box_size in opencv form
+     """
+     x1 = face.left()
+     y1 = face.top()
+     x2 = face.right()
+     y2 = face.bottom()
+     size_bb = int(max(x2 - x1, y2 - y1) * scale)
+     if minsize:
+         if size_bb < minsize:
+             size_bb = minsize
+     center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
+
+     # Check for out of bounds, x-y top left corner
+     x1 = max(int(center_x - size_bb // 2), 0)
+     y1 = max(int(center_y - size_bb // 2), 0)
+     # Check for too big bb size for given x, y
+     size_bb = min(width - x1, size_bb)
+     size_bb = min(height - y1, size_bb)
+
+     return x1, y1, size_bb
+
+
+ def preprocess_image(image, cuda=True):
+     """
+     Preprocesses the image such that it can be fed into our network.
+     During this process we invoke PIL to cast it into a PIL image.
+     :param image: numpy image in opencv form (i.e., BGR and of shape
+     :return: pytorch tensor of shape [1, 3, image_size, image_size], not
+         necessarily casted to cuda
+     """
+     # Revert from BGR
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     # Preprocess using the preprocessing function used during training and
+     # casting it to PIL image
+     preprocess = xception_default_data_transforms['test']
+     preprocessed_image = preprocess(pil_image.fromarray(image))
+     # Add first dimension as the network expects a batch
+     preprocessed_image = preprocessed_image.unsqueeze(0)
+     if cuda:
+         preprocessed_image = preprocessed_image.cuda()
+     return preprocessed_image
+
+ def truncate(x, truncation, mean_style):
+     return truncation*x + (1-truncation)*mean_style
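A brief sketch (not part of the commit) of `accumulate`, which `train.py` calls after every optimizer step to maintain EMA copies of the generators; it assumes a PyTorch version contemporary with this code, since the positional `add_(alpha, tensor)` form it relies on is deprecated in newer releases.

```python
import copy
from torch import nn
from util import accumulate

model = nn.Linear(8, 8)                       # stand-in for a generator
model_ema = copy.deepcopy(model).eval()       # EMA copy, never trained directly

# ... after each training step:
accumulate(model_ema, model, decay=0.999)     # model_ema <- 0.999 * model_ema + 0.001 * model
```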