johnowhitaker committed
Commit d24b25a
1 Parent(s): 9d691c7

Create app.py

Files changed (1): app.py (+636 lines)

app.py ADDED
#@title Gradio demo (used in space: )

from matplotlib import pyplot as plt
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import gradio as gr


#@title Defining Generator and associated code ourselves without the GPU requirements
import os
import json
import multiprocessing
from random import random
import math
from math import log2, floor
from functools import partial
from contextlib import contextmanager, ExitStack
from pathlib import Path
from shutil import rmtree

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from PIL import Image
import torchvision
from torchvision import transforms
from kornia.filters import filter2d

from lightweight_gan.diff_augment import DiffAugment
from lightweight_gan.version import __version__

from tqdm import tqdm
from einops import rearrange, reduce, repeat

from adabelief_pytorch import AdaBelief

# helpers

def exists(val):
    return val is not None

@contextmanager
def null_context():
    yield

def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts

def is_power_of_two(val):
    return log2(val).is_integer()

def default(val, d):
    return val if exists(val) else d

def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool

def cycle(iterable):
    while True:
        for i in iterable:
            yield i

def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException

def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield

def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)

def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res

def safe_div(n, d):
    try:
        res = n / d
    except ZeroDivisionError:
        prefix = '' if int(n >= 0) else '-'
        res = float(f'{prefix}inf')
    return res

# loss functions

def gen_hinge_loss(fake, real):
    return fake.mean()

def hinge_loss(real, fake):
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

def dual_contrastive_loss(real_logits, fake_logits):
    device = real_logits.device
    real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

    def loss_half(t1, t2):
        t1 = rearrange(t1, 'i -> i ()')
        t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
        t = torch.cat((t1, t2), dim = -1)
        return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))

    return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)

# helper classes

class NanException(Exception):
    pass

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new

class RandomApply(nn.Module):
    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob
    def forward(self, x):
        fn = self.fn if random() < self.prob else self.fn_else
        return fn(x)

class ChanNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class SumBranches(nn.Module):
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)
    def forward(self, x):
        return sum(map(lambda fn: fn(x), self.branches))

class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)
    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)

class Noise(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise = None):
        b, _, h, w, device = *x.shape, x.device

        if not exists(noise):
            noise = torch.randn(b, 1, h, w, device = device)

        return x + self.weight * noise

def Conv2dSame(dim_in, dim_out, kernel_size, bias = True):
    pad_left = kernel_size // 2
    pad_right = (pad_left - 1) if (kernel_size % 2) == 0 else pad_left

    return nn.Sequential(
        nn.ZeroPad2d((pad_left, pad_right, pad_left, pad_right)),
        nn.Conv2d(dim_in, dim_out, kernel_size, bias = bias)
    )

# attention

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

class LinearAttention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]

        # linear attention

        lin_q, lin_k, lin_v = (self.to_lin_q(fmap), *self.to_lin_kv(fmap).chunk(2, dim = 1))
        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (lin_q, lin_k, lin_v))

        lin_q = lin_q.softmax(dim = -1)
        lin_k = lin_k.softmax(dim = -2)

        lin_q = lin_q * self.scale

        context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
        lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
        lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # conv-like full attention

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))

        k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
        v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)

        k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = self.dim_head), (k, v))

        q = rearrange(q, 'b c ... -> b (...) c') * self.scale

        sim = einsum('b i d, b i j d -> b i j', q, k)
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        attn = sim.softmax(dim = -1)

        full_out = einsum('b i j, b i j d -> b i d', attn, v)
        full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # add outputs of linear attention + conv-like full attention

        lin_out = self.nonlin(lin_out)
        out = torch.cat((lin_out, full_out), dim = 1)
        return self.to_out(out)

# dataset

# NOTE: EXTS is referenced by ImageDataset below but was not defined in this file;
# the list here is an assumed default matching lightweight_gan's usual extensions.
EXTS = ['jpg', 'jpeg', 'png']

def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

class identity(object):
    def __call__(self, tensor):
        return tensor

class expand_greyscale(object):
    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        channels = tensor.shape[0]
        num_target_channels = 4 if self.transparent else 3

        if channels == num_target_channels:
            return tensor

        alpha = None
        if channels == 1:
            color = tensor.expand(3, -1, -1)
        elif channels == 2:
            color = tensor[:1].expand(3, -1, -1)
            alpha = tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        if not exists(alpha) and self.transparent:
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)

        return color if not self.transparent else torch.cat((color, alpha))

def resize_to_minimum_size(min_size, image):
    if max(*image.size) < min_size:
        return torchvision.transforms.functional.resize(image, min_size)
    return image

class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        transparent = False,
        greyscale = False,
        aug_prob = 0.
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
        assert len(self.paths) > 0, f'No images were found in {folder} for training'

        if transparent:
            num_channels = 4
            pillow_mode = 'RGBA'
            expand_fn = expand_greyscale(transparent)
        elif greyscale:
            num_channels = 1
            pillow_mode = 'L'
            expand_fn = identity()
        else:
            num_channels = 3
            pillow_mode = 'RGB'
            expand_fn = expand_greyscale(transparent)

        convert_image_fn = partial(convert_image_to, pillow_mode)

        self.transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
            transforms.Resize(image_size),
            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
            transforms.ToTensor(),
            transforms.Lambda(expand_fn)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# augmentations

def random_hflip(tensor, prob):
    if prob > random():
        return tensor
    return torch.flip(tensor, dims=(3,))

class AugWrapper(nn.Module):
    def __init__(self, D, image_size):
        super().__init__()
        self.D = D

    def forward(self, images, prob = 0., types = [], detach = False, **kwargs):
        context = torch.no_grad if detach else null_context

        with context():
            if random() < prob:
                images = random_hflip(images, prob=0.5)
                images = DiffAugment(images, types=types)

        return self.D(images, **kwargs)

# modifiable global variables

norm_class = nn.BatchNorm2d

def upsample(scale_factor = 2):
    return nn.Upsample(scale_factor = scale_factor)

# squeeze excitation classes

# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm

class GlobalContext(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        chan_intermediate = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        context = self.to_k(x)
        context = context.flatten(2).softmax(dim = -1)
        out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
        out = out.unsqueeze(-1)
        return self.net(out)

# frequency channel attention
# https://arxiv.org/abs/2012.11879

def get_1d_dct(i, freq, L):
    result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    return result * (1 if freq == 0 else math.sqrt(2))

def get_dct_weights(width, channel, fidx_u, fidx_v):
    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        for x in range(width):
            for y in range(width):
                coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
                dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value

    return dct_weights

class FCANet(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction = 4,
        width
    ):
        super().__init__()

        freq_w, freq_h = ([0] * 8), list(range(8))  # in paper, it seems 16 frequencies was ideal
        dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
        self.register_buffer('dct_weights', dct_weights)

        chan_intermediate = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
        return self.net(x)

# generative adversarial network

class Generator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 512,
        fmap_inverse_coef = 12,
        transparent = False,
        greyscale = False,
        attn_res_layers = [],
        freq_chan_attn = False
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim = 1)
        )

        num_layers = int(resolution) - 2
        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
        self.sle_map = dict(self.sle_map)

        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in = chan_out,
                        chan_out = sle_chan_out,
                        width = 2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in = chan_out,
                        chan_out = sle_chan_out
                    )

            layer = nn.ModuleList([
                nn.Sequential(
                    upsample(),
                    Blur(),
                    Conv2dSame(chan_in, chan_out * 2, 4),
                    Noise(),
                    norm_class(chan_out * 2),
                    nn.GLU(dim = 1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)

    def forward(self, x):
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        x = F.normalize(x, dim = 1)

        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)

# Initialize a generator model
gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])

# Load from local saved state dict
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))

# Load from model hub:
class GeneratorWithPyTorchModelHubMixin(gan_new.__class__, PyTorchModelHubMixin):
    pass
gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
gan_new = gan_new.from_pretrained('johnowhitaker/colorb_gan', latent_dim=256, image_size=256, attn_res_layers = [32])

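# Quick sanity check (a hedged example, not part of the original app; uncomment to run):
# a single 256-dim latent should decode to a 3 x 256 x 256 image tensor.
# with torch.no_grad():
#     sample = gan_new(torch.randn(1, 256))
#     print(sample.shape)  # expected: torch.Size([1, 3, 256, 256])
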
def gen_ims(n_rows):
    # Sample n_rows**2 latent vectors, decode them, and tile the outputs into one uint8 image grid
    ims = gan_new(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
    grid = torchvision.utils.make_grid(ims, nrow=int(n_rows)).permute(1, 2, 0).detach().cpu().numpy()
    return (grid*255).astype(np.uint8)


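# Example of calling gen_ims directly (a sketch of what the Interface does per request,
# commented out so it doesn't run when the Space starts):
# plt.imshow(gen_ims(3))
# plt.axis('off')
# plt.show()
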
iface = gr.Interface(fn=gen_ims,
                     inputs=[gr.inputs.Slider(minimum=1, maximum=6, step=1, default=3, label="N rows")],
                     outputs=[gr.outputs.Image(type="numpy", label="Generated Images")],
                     title='Demo for Colorbgan model',
                     article='A lightweight GAN trained on johnowhitaker/colorbs. See https://huggingface.co/johnowhitaker/orbgan_e1 for training and inference scripts.'
                     )
iface.launch()
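
# Optional extra (a sketch, not used by the Interface above): the slerp helper defined
# earlier can interpolate between two latents to render in-between frames.
# z1, z2 = torch.randn(1, 256), torch.randn(1, 256)
# frames = [gan_new(slerp(t, z1, z2)).clamp_(0., 1.) for t in np.linspace(0., 1., 8)]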