Commit
·
0312396
1
Parent(s):
9e50096
Update app.py (#5)
Browse files- Update app.py (90040e744146b3329d6aa32da5f638826be22dbe)
Co-authored-by: Saptarshi Mukherjee <zombie-596@users.noreply.huggingface.co>
app.py
CHANGED
|
@@ -107,7 +107,66 @@ class SimpleUnet(nn.Module):
|
|
| 107 |
x = up(x, t)
|
| 108 |
return self.output(x)
|
| 109 |
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
model = SimpleUnet()
|
| 112 |
|
| 113 |
st.title("Generatig images using a diffusion model")
|
|
|
|
| 107 |
x = up(x, t)
|
| 108 |
return self.output(x)
|
| 109 |
|
| 110 |
+
def extract(a, t, x_shape):
|
| 111 |
+
batch_size = t.shape[0]
|
| 112 |
+
out = a.gather(-1, t.cpu())
|
| 113 |
+
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
| 114 |
+
|
| 115 |
+
@torch.no_grad()
|
| 116 |
+
def p_sample(model, x, t, t_index):
|
| 117 |
+
betas_t = extract(betas, t, x.shape)
|
| 118 |
+
sqrt_one_minus_alphas_cumprod_t = extract(
|
| 119 |
+
sqrt_one_minus_alphas_cumprod, t, x.shape
|
| 120 |
+
)
|
| 121 |
+
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
|
| 122 |
+
|
| 123 |
+
# Equation 11 in the paper
|
| 124 |
+
# Use our model (noise predictor) to predict the mean
|
| 125 |
+
model_mean = sqrt_recip_alphas_t * (
|
| 126 |
+
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if t_index == 0:
|
| 130 |
+
return model_mean
|
| 131 |
+
else:
|
| 132 |
+
posterior_variance_t = extract(posterior_variance, t, x.shape)
|
| 133 |
+
noise = torch.randn_like(x)
|
| 134 |
+
# Algorithm 2 line 4:
|
| 135 |
+
return model_mean + torch.sqrt(posterior_variance_t) * noise
|
| 136 |
+
|
| 137 |
+
# Algorithm 2 but save all images:
|
| 138 |
+
@torch.no_grad()
|
| 139 |
+
def p_sample_loop(model, shape):
|
| 140 |
+
device = next(model.parameters()).device
|
| 141 |
+
|
| 142 |
+
b = shape[0]
|
| 143 |
+
# start from pure noise (for each example in the batch)
|
| 144 |
+
img = torch.randn(shape, device=device)
|
| 145 |
+
imgs = []
|
| 146 |
+
|
| 147 |
+
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
|
| 148 |
+
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), 3)
|
| 149 |
+
imgs.append(img.cpu().numpy())
|
| 150 |
+
return imgs
|
| 151 |
+
|
| 152 |
+
@torch.no_grad()
|
| 153 |
+
def sample(model, image_size, batch_size=16, channels=3):
|
| 154 |
+
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
|
| 155 |
+
|
| 156 |
+
samples = sample(model, image_size=img_size, batch_size=64, channels=3)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
reverse_transforms = transforms.Compose([
|
| 160 |
+
transforms.Lambda(lambda t: (t + 1) / 2),
|
| 161 |
+
transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
|
| 162 |
+
transforms.Lambda(lambda t: t * 255.),
|
| 163 |
+
transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
|
| 164 |
+
transforms.ToPILImage(),
|
| 165 |
+
])
|
| 166 |
+
|
| 167 |
+
for i in range(10):
|
| 168 |
+
img = reverse_transforms(torch.Tensor((samples[-1][i].reshape(3, img_size, img_size))))
|
| 169 |
+
plt.imshow(img)
|
| 170 |
model = SimpleUnet()
|
| 171 |
|
| 172 |
st.title("Generatig images using a diffusion model")
|