johnowhitaker committed
Commit: c9bf0bf
Parent: 5e36def

Update app.py

Files changed (1): app.py (+186, -0)
app.py CHANGED
@@ -19,8 +19,194 @@ from fastprogress.fastprogress import master_bar, progress_bar
 from IPython.display import HTML
 from base64 import b64encode
 
+# Definitions
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+def sinc(x):
+    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+
+def lanczos(x, a):
+    cond = torch.logical_and(-a < x, x < a)
+    out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
+    return out / out.sum()
+
+
+def ramp(ratio, width):
+    n = math.ceil(width / ratio + 1)
+    out = torch.empty([n])
+    cur = 0
+    for i in range(out.shape[0]):
+        out[i] = cur
+        cur += ratio
+    return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
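For orientation (not part of the commit): ramp generates symmetric sample offsets for a given downscaling ratio, and lanczos windows sinc into a normalised 1-D filter kernel, which the resample function further down applies separably along each axis. A minimal sketch of how they combine, assuming the definitions above are in scope:

    # Build the kernel used when halving an image's height or width (ratio 0.5, a=2).
    kernel = lanczos(ramp(0.5, 2), 2)
    print(kernel.shape)  # torch.Size([7]) - an odd-length kernel
    print(kernel.sum())  # tensor(1.) - normalised so overall brightness is preserved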
+class Prompt(nn.Module):
+    def __init__(self, embed, weight=1., stop=float('-inf')):
+        super().__init__()
+        self.register_buffer('embed', embed)
+        self.register_buffer('weight', torch.as_tensor(weight))
+        self.register_buffer('stop', torch.as_tensor(stop))
+
+    def forward(self, input):
+        input_normed = F.normalize(input.unsqueeze(1), dim=2)
+        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
+        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
+        dists = dists * self.weight.sign()
+        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
+
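For readers unfamiliar with this loss: forward computes the squared spherical distance between L2-normalised embeddings, and replace_grad (defined further down) floors the value at stop while leaving gradients intact. A hypothetical usage sketch, where image_embeds stands in for CLIP image features produced elsewhere in the file:

    # embed: (1, 512) CLIP text embedding; image_embeds: (cutn, 512) image embeddings.
    prompt = Prompt(embed).to(device)
    loss = prompt(image_embeds)  # scalar; minimising it pulls the images toward the text
    loss.backward()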
+class MakeCutouts(nn.Module):
+    def __init__(self, cut_size, cutn, cut_pow=1.):
+        super().__init__()
+        self.cut_size = cut_size
+        self.cutn = cutn
+        self.cut_pow = cut_pow
+        self.augs = nn.Sequential(
+            K.RandomHorizontalFlip(p=0.5),
+            K.RandomSharpness(0.3, p=0.4),
+            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
+            K.RandomPerspective(0.2, p=0.4),
+            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
+        self.noise_fac = 0.1
+
+    def forward(self, input):
+        sideY, sideX = input.shape[2:4]
+        max_size = min(sideX, sideY)
+        min_size = min(sideX, sideY, self.cut_size)
+        cutouts = []
+        for _ in range(self.cutn):
+            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+            offsetx = torch.randint(0, sideX - size + 1, ())
+            offsety = torch.randint(0, sideY - size + 1, ())
+            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+        batch = self.augs(torch.cat(cutouts, dim=0))
+        if self.noise_fac:
+            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
+            batch = batch + facs * torch.randn_like(batch)
+        return batch
+
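How this is typically driven (a sketch, assuming im is a (1, 3, H, W) tensor in [0, 1] such as ImStack.forward below returns, and using the perceptor, normalize and make_cutouts names from the CLIP setup below):

    cutouts = make_cutouts(im)  # (cutn, 3, cut_size, cut_size) augmented random crops
    image_embeds = perceptor.encode_image(normalize(cutouts)).float()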
+def resample(input, size, align_corners=True):
+    n, c, h, w = input.shape
+    dh, dw = size
+
+    input = input.view([n * c, 1, h, w])
+
+    if dh < h:
+        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+        pad_h = (kernel_h.shape[0] - 1) // 2
+        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
+        input = F.conv2d(input, kernel_h[None, None, :, None])
+
+    if dw < w:
+        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+        pad_w = (kernel_w.shape[0] - 1) // 2
+        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
+        input = F.conv2d(input, kernel_w[None, None, None, :])
+
+    input = input.view([n, c, h, w])
+    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
+
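resample low-pass filters with the Lanczos kernel before bicubic interpolation, which suppresses aliasing when shrinking crops larger than cut_size. A quick shape check, assuming the helpers above:

    x = torch.rand(1, 3, 512, 512)
    y = resample(x, (224, 224))
    print(y.shape)  # torch.Size([1, 3, 224, 224])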
+class ReplaceGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x_forward, x_backward):
+        ctx.shape = x_backward.shape
+        return x_forward
+
+    @staticmethod
+    def backward(ctx, grad_in):
+        return None, grad_in.sum_to_size(ctx.shape)
+
+
+replace_grad = ReplaceGrad.apply
+
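replace_grad(x_forward, x_backward) returns x_forward's value but routes all gradient to x_backward; Prompt uses it to clamp the loss at stop without zeroing gradients. A toy demonstration (mine, not from the commit):

    a = torch.tensor(2.0)
    b = torch.tensor(5.0, requires_grad=True)
    out = replace_grad(a, b)
    out.backward()
    print(out.item(), b.grad)  # 2.0 tensor(1.) - value taken from a, gradient sent to b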
+# Set up CLIP
+perceptor = clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)
+normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                 std=[0.26862954, 0.26130258, 0.27577711])
+cut_size = perceptor.visual.input_resolution
+cutn = 64
+cut_pow = 1
+make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
+
+# ImStack
+class ImStack(nn.Module):
+    """This class represents an image as a series of stacked arrays, where each is 1/2
+    the resolution of the next. This is useful e.g. when trying to create an image to minimise
+    some loss - parameters in the early (small) layers can have an effect on the overall
+    structure and shapes while those in later layers act as residuals and fill in fine detail.
+    """
+
+    def __init__(self, n_layers=3, base_size=32, scale=2,
+                 init_image=None, out_size=256, decay=0.7):
+        """Constructs the Image Stack
+
+        Args:
+            TODO
+        """
+        super().__init__()
+        self.n_layers = n_layers
+        self.base_size = base_size
+        self.sig = nn.Sigmoid()
+        self.layers = []
+
+        for i in range(n_layers):
+            side = base_size * (scale**i)
+            tim = torch.randn((3, side, side)).to(device) * (decay**i)
+            self.layers.append(tim)
+
+        self.scalers = [nn.Upsample(scale_factor=out_size / l.shape[1], mode='bilinear', align_corners=False) for l in self.layers]
+
+        self.preview_scalers = [nn.Upsample(scale_factor=224 / l.shape[1], mode='bilinear', align_corners=False) for l in self.layers]
+
+        if init_image is not None:  # Given a PIL image, decompose it into a stack
+            downscalers = [nn.Upsample(scale_factor=l.shape[1] / out_size, mode='bilinear', align_corners=False) for l in self.layers]
+            final_side = base_size * (scale ** n_layers)
+            im = torch.tensor(np.array(init_image.resize((out_size, out_size))) / 255).clip(1e-3, 1 - 1e-3)  # Between 0 and 1 (non-inclusive)
+            im = im.permute(2, 0, 1).unsqueeze(0).to(device)  # torch.log(im/(1-im))
+            for i in range(n_layers):
+                self.layers[i] *= 0  # Zero out the layers
+            for i in range(n_layers):
+                side = base_size * (scale**i)
+                out = self.forward()
+                residual = torch.logit(im) - torch.logit(out)
+                Image.fromarray((torch.logit(residual).detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8)).save(f'residual{i}.png')
+                self.layers[i] = downscalers[i](residual).squeeze()
+
+        for l in self.layers:
+            l.requires_grad = True
+
+    def forward(self):
+        im = self.scalers[0](self.layers[0].unsqueeze(0))
+        for i in range(1, self.n_layers):
+            im += self.scalers[i](self.layers[i].unsqueeze(0))
+        return self.sig(im)
+
+    def preview(self, n_preview=2):
+        im = self.preview_scalers[0](self.layers[0].unsqueeze(0))
+        for i in range(1, n_preview):
+            im += self.preview_scalers[i](self.layers[i].unsqueeze(0))
+        return self.sig(im)
+
+    def to_pil(self):
+        return Image.fromarray((self.forward().detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8))
+
+    def preview_pil(self):
+        return Image.fromarray((self.preview().detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8))
+
+    def save(self, fn):
+        self.to_pil().save(fn)
+
+    def plot_layers(self):
+        fig, axs = plt.subplots(1, self.n_layers, figsize=(15, 5))
+        for i in range(self.n_layers):
+            im = (self.sig(self.layers[i].unsqueeze(0)).detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8)
+            axs[i].imshow(im)
+
+
+
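A minimal sketch of the intended use, per the docstring: optimise the stacked layers directly, with the small early layers shaping coarse structure and later ones adding detail (the learning rate below is an assumed value, not from the commit):

    stack = ImStack(n_layers=3, base_size=32, out_size=256)
    opt = torch.optim.Adam(stack.layers, lr=0.05)  # lr is assumed
    im = stack.forward()  # (1, 3, 256, 256), squashed into (0, 1) by the sigmoid
    stack.save('stack_preview.png')  # hypothetical filename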
 
 def generate(text, n_steps):
+    # Encode prompt
+    embed = perceptor.encode_text(clip.tokenize(text).to(device)).float()
     #todo
     return np.random.random((128, 128, 3)).astype(np.uint8)
 
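The committed generate is still a stub: the #todo and the random-noise return are unchanged here. For context, the pieces in this commit compose into a standard CLIP-guided image-optimisation loop along these lines - a sketch of one plausible completion, not the author's implementation, with assumed hyperparameters:

    def generate_sketch(text, n_steps):
        # Encode the prompt once, then repeatedly nudge the image toward it.
        embed = perceptor.encode_text(clip.tokenize(text).to(device)).float()
        prompt = Prompt(embed).to(device)
        stack = ImStack(n_layers=3, out_size=256)
        opt = torch.optim.Adam(stack.layers, lr=0.05)  # assumed hyperparameters
        for _ in range(n_steps):
            opt.zero_grad()
            cutouts = make_cutouts(stack.forward())
            image_embeds = perceptor.encode_image(normalize(cutouts)).float()
            prompt(image_embeds).backward()
            opt.step()
        return np.array(stack.to_pil())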