teticio committed
Commit 9a9737e
1 Parent(s): 7aaaf62

merge with diffusers latest version
scripts/train_unconditional.py CHANGED
@@ -35,24 +35,19 @@ def main(args):
     output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
     logging_dir = os.path.join(output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
     )
 
+    if args.vae is not None:
+        vqvae = AutoencoderKL.from_pretrained(args.vae)
+
     if args.from_pretrained is not None:
-        #model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
-        pretrained = LDMPipeline.from_pretrained(args.from_pretrained)
-        vqvae = pretrained.vqvae
-        model = pretrained.unet
+        model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
     else:
-        vqvae = AutoencoderKL(sample_size=args.resolution,
-                              in_channels=1,
-                              out_channels=1,
-                              latent_channels=1,
-                              layers_per_block=2)
         model = UNet2DModel(
-            sample_size=args.resolution,
             in_channels=1,
             out_channels=1,
             layers_per_block=2,
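The first addition wires `Accelerator`'s built-in gradient accumulation into the script; it only takes effect inside the `accelerator.accumulate(model)` block used further down, where `accelerator.sync_gradients` flags the steps on which gradients are actually applied. A minimal, self-contained sketch of the pattern (model, data, and step count are illustrative, not from this commit):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

for batch in [torch.randn(2, 8) for _ in range(8)]:
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()  # placeholder loss
        accelerator.backward(loss)
        # sync_gradients is True only on every 4th micro-batch;
        # clipping and the actual optimizer update happen then.
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```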
@@ -75,10 +70,12 @@ def main(args):
         ),
     )
 
-    #noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
-    #                                tensor_format="pt")
-    noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
-                                    tensor_format="pt")
+    if args.scheduler == "ddpm":
+        noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
+                                        tensor_format="pt")
+    else:
+        noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
+                                        tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
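The scheduler choice becomes a runtime flag instead of a commented-out alternative. One caveat when tracking diffusers upstream: the `tensor_format="pt"` argument was accepted by schedulers of this era, but later releases dropped it, so a version-agnostic helper would look roughly like this (a sketch under that assumption, not part of the commit):

```python
from diffusers import DDIMScheduler, DDPMScheduler

def build_scheduler(name: str = "ddpm", num_train_timesteps: int = 1000):
    # Mirrors the if/else above: "ddpm" selects stochastic ancestral
    # sampling; anything else falls through to deterministic DDIM.
    cls = DDPMScheduler if name == "ddpm" else DDIMScheduler
    return cls(num_train_timesteps=num_train_timesteps)
```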
@@ -115,7 +112,13 @@ def main(args):
     )
 
     def transforms(examples):
-        images = [augmentations(image) for image in examples["image"]]
+        if args.vae is not None:
+            images = [
+                augmentations(image).convert("RGB")
+                for image in examples["image"]
+            ]
+        else:
+            images = [augmentations(image) for image in examples["image"]]
         return {"input": images}
 
     dataset.set_transform(transforms)
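The `--vae` branch converts spectrograms to RGB because a pretrained `AutoencoderKL` (e.g. the Stable Diffusion VAE) expects 3-channel input, whereas this script otherwise trains on single-channel mels; note the branch also assumes `augmentations` returns a `PIL.Image`, since tensors have no `.convert` method. A tiny illustration of the conversion (sizes assumed):

```python
from PIL import Image

mel = Image.new("L", (256, 256))  # stand-in single-channel spectrogram
rgb = mel.convert("RGB")          # replicate the channel three times
print(rgb.mode, rgb.size)         # -> RGB (256, 256)
```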
@@ -181,27 +184,42 @@ def main(args):
                 device=clean_images.device,
             ).long()
 
-            clean_latents = vqvae.encode(clean_images)["sample"]
+            if args.vae is not None:
+                with torch.no_grad():
+                    clean_images = vqvae.encode(
+                        clean_images).latent_dist.sample()
+
             # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(clean_latents, noise,
-                                                      timesteps)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise,
+                                                     timesteps)
 
             with accelerator.accumulate(model):
                 # Predict the noise residual
-                latents = model(noisy_latents, timesteps)["sample"]
-                noise_pred = vqvae.decode(latents)["sample"]
+                images = model(noisy_images, timesteps)["sample"]
+                noise_pred = vqvae.decode(images)["sample"]
                 loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
 
-                accelerator.clip_grad_norm_(model.parameters(), 1.0)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
                 if args.use_ema:
                     ema_model.step(model)
                 optimizer.zero_grad()
 
-            progress_bar.update(1)
+            if args.vae is not None:
+                with torch.no_grad():
+                    images = [
+                        image.convert('L')
+                        for image in vqvae.decode(images)["sample"]
+                    ]
+
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {
                 "loss": loss.detach().item(),
                 "lr": lr_scheduler.get_last_lr()[0],
@@ -211,7 +229,6 @@ def main(args):
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
         progress_bar.close()
 
         accelerator.wait_for_everyone()
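Relatedly, the `global_step += 1` removed here has moved up into the `accelerator.sync_gradients` branch, so with accumulation enabled the step counter (and the progress bar) tracks optimizer updates rather than micro-batches.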
@@ -219,17 +236,19 @@ def main(args):
         # Generate sample images for visual inspection
         if accelerator.is_main_process:
             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
-                #pipeline = DDPMPipeline(
-                #    unet=accelerator.unwrap_model(
-                #        ema_model.averaged_model if args.use_ema else model),
-                #    scheduler=noise_scheduler,
-                #)
-                pipeline = LDMPipeline(
-                    unet=accelerator.unwrap_model(
-                        ema_model.averaged_model if args.use_ema else model),
-                    vqvae=vqvae,
-                    scheduler=noise_scheduler,
-                )
+                if args.vae is not None:
+                    pipeline = LDMPipeline(
+                        unet=accelerator.unwrap_model(
+                            ema_model.averaged_model if args.use_ema else model),
+                        vqvae=vqvae,
+                        scheduler=noise_scheduler,
+                    )
+                else:
+                    pipeline = DDPMPipeline(
+                        unet=accelerator.unwrap_model(
+                            ema_model.averaged_model if args.use_ema else model),
+                        scheduler=noise_scheduler,
+                    )
 
             # save the model
             if args.push_to_hub:
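Sample generation now builds whichever pipeline matches the training setup. A hedged usage sketch of drawing inspection images from it (batch size and seed are illustrative, and the output key varies by diffusers version):

```python
import torch

generator = torch.Generator().manual_seed(42)
out = pipeline(batch_size=8, generator=generator)
# diffusers of this era returned a dict keyed by "sample"; later
# releases return an output object with an `.images` attribute.
images = out["sample"] if "sample" in out else out.images
```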
@@ -325,6 +344,14 @@ if __name__ == "__main__":
     parser.add_argument("--hop_length", type=int, default=512)
     parser.add_argument("--from_pretrained", type=str, default=None)
     parser.add_argument("--start_epoch", type=int, default=0)
+    parser.add_argument("--scheduler",
+                        type=str,
+                        default="ddpm",
+                        help="ddpm or ddim")
+    parser.add_argument("--vae",
+                        type=str,
+                        default=None,
+                        help="pretrained VAE model for latent diffusion")
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
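Taken together, the two new flags make one script cover both regimes: `--scheduler` picks the DDPM or DDIM noise schedule, and passing `--vae <path>` switches training into the latent space of a pretrained autoencoder (an invocation might look like `python scripts/train_unconditional.py --scheduler ddim --vae <path-to-vae>`; remaining flags omitted).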
 
scripts/train_vae.py CHANGED
@@ -6,6 +6,7 @@
 # grayscale
 # add vae to train_uncond (no_grad)
 # update README
+# merge in changes to train_unconditional
 
 import os
 import argparse