himanshu-skid19 zombie-596 commited on
Commit
0312396
·
1 Parent(s): 9e50096

Update app.py (#5)

Browse files

- Update app.py (90040e744146b3329d6aa32da5f638826be22dbe)


Co-authored-by: Saptarshi Mukherjee <zombie-596@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +60 -1
app.py CHANGED
@@ -107,7 +107,66 @@ class SimpleUnet(nn.Module):
107
  x = up(x, t)
108
  return self.output(x)
109
 
110
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  model = SimpleUnet()
112
 
113
  st.title("Generatig images using a diffusion model")
 
107
  x = up(x, t)
108
  return self.output(x)
109
 
110
+ def extract(a, t, x_shape):
111
+ batch_size = t.shape[0]
112
+ out = a.gather(-1, t.cpu())
113
+ return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
114
+
115
+ @torch.no_grad()
116
+ def p_sample(model, x, t, t_index):
117
+ betas_t = extract(betas, t, x.shape)
118
+ sqrt_one_minus_alphas_cumprod_t = extract(
119
+ sqrt_one_minus_alphas_cumprod, t, x.shape
120
+ )
121
+ sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
122
+
123
+ # Equation 11 in the paper
124
+ # Use our model (noise predictor) to predict the mean
125
+ model_mean = sqrt_recip_alphas_t * (
126
+ x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
127
+ )
128
+
129
+ if t_index == 0:
130
+ return model_mean
131
+ else:
132
+ posterior_variance_t = extract(posterior_variance, t, x.shape)
133
+ noise = torch.randn_like(x)
134
+ # Algorithm 2 line 4:
135
+ return model_mean + torch.sqrt(posterior_variance_t) * noise
136
+
137
+ # Algorithm 2 but save all images:
138
+ @torch.no_grad()
139
+ def p_sample_loop(model, shape):
140
+ device = next(model.parameters()).device
141
+
142
+ b = shape[0]
143
+ # start from pure noise (for each example in the batch)
144
+ img = torch.randn(shape, device=device)
145
+ imgs = []
146
+
147
+ for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
148
+ img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), 3)
149
+ imgs.append(img.cpu().numpy())
150
+ return imgs
151
+
152
+ @torch.no_grad()
153
+ def sample(model, image_size, batch_size=16, channels=3):
154
+ return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
155
+
156
+ samples = sample(model, image_size=img_size, batch_size=64, channels=3)
157
+
158
+
159
+ reverse_transforms = transforms.Compose([
160
+ transforms.Lambda(lambda t: (t + 1) / 2),
161
+ transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
162
+ transforms.Lambda(lambda t: t * 255.),
163
+ transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
164
+ transforms.ToPILImage(),
165
+ ])
166
+
167
+ for i in range(10):
168
+ img = reverse_transforms(torch.Tensor((samples[-1][i].reshape(3, img_size, img_size))))
169
+ plt.imshow(img)
170
  model = SimpleUnet()
171
 
172
  st.title("Generatig images using a diffusion model")