merge with diffusers latest version
- scripts/train_unconditional.py (+61 -34)
- scripts/train_vae.py (+1 -0)
scripts/train_unconditional.py
CHANGED
@@ -35,24 +35,19 @@ def main(args):
     output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
     logging_dir = os.path.join(output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
     )

+    if args.vae is not None:
+        vqvae = AutoencoderKL.from_pretrained(args.vae)
+
     if args.from_pretrained is not None:
-
-        pretrained = LDMPipeline.from_pretrained(args.from_pretrained)
-        vqvae = pretrained.vqvae
-        model = pretrained.unet
+        model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
     else:
-        vqvae = AutoencoderKL(sample_size=args.resolution,
-                              in_channels=1,
-                              out_channels=1,
-                              latent_channels=1,
-                              layers_per_block=2)
         model = UNet2DModel(
-            sample_size=args.resolution,
             in_channels=1,
             out_channels=1,
             layers_per_block=2,
@@ -75,10 +70,12 @@ def main(args):
         ),
     )

-
-
-
-
+    if args.scheduler == "ddpm":
+        noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
+                                        tensor_format="pt")
+    else:
+        noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
+                                        tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -115,7 +112,13 @@ def main(args):
     )

     def transforms(examples):
-
+        if args.vae is not None:
+            images = [
+                augmentations(image).convert("RGB")
+                for image in examples["image"]
+            ]
+        else:
+            images = [augmentations(image) for image in examples["image"]]
         return {"input": images}

     dataset.set_transform(transforms)
@@ -181,27 +184,42 @@ def main(args):
                 device=clean_images.device,
             ).long()

-
+            if args.vae is not None:
+                with torch.no_grad():
+                    clean_images = vqvae.encode(
+                        clean_images).latent_dist.sample()
+
             # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-
-
+            noisy_images = noise_scheduler.add_noise(clean_images, noise,
+                                                     timesteps)

             with accelerator.accumulate(model):
                 # Predict the noise residual
-
-                noise_pred = vqvae.decode(
+                images = model(noisy_images, timesteps)["sample"]
+                noise_pred = vqvae.decode(images)["sample"]
                 loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)

-                accelerator.
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
                 if args.use_ema:
                     ema_model.step(model)
                 optimizer.zero_grad()

-
+            if args.vae is not None:
+                with torch.no_grad():
+                    images = [
+                        image.convert('L')
+                        for image in vqvae.decode(images)["sample"]
+                    ]
+
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {
                 "loss": loss.detach().item(),
                 "lr": lr_scheduler.get_last_lr()[0],
@@ -211,7 +229,6 @@ def main(args):
                     logs["ema_decay"] = ema_model.decay
                 progress_bar.set_postfix(**logs)
                 accelerator.log(logs, step=global_step)
-                global_step += 1
             progress_bar.close()

         accelerator.wait_for_everyone()
@@ -219,17 +236,19 @@ def main(args):
        # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
-
-
-
-
-
-
-
-
-
-
-
+                if args.vae is not None:
+                    pipeline = LDMPipeline(
+                        unet=accelerator.unwrap_model(
+                            ema_model.averaged_model if args.use_ema else model),
+                        vqvae=vqvae,
+                        scheduler=noise_scheduler,
+                    )
+                else:
+                    pipeline = DDPMPipeline(
+                        unet=accelerator.unwrap_model(
+                            ema_model.averaged_model if args.use_ema else model),
+                        scheduler=noise_scheduler,
+                    )

                # save the model
                if args.push_to_hub:
@@ -325,6 +344,14 @@ if __name__ == "__main__":
     parser.add_argument("--hop_length", type=int, default=512)
     parser.add_argument("--from_pretrained", type=str, default=None)
     parser.add_argument("--start_epoch", type=int, default=0)
+    parser.add_argument("--scheduler",
+                        type=str,
+                        default="ddpm",
+                        help="ddpm or ddim")
+    parser.add_argument("--vae",
+                        type=str,
+                        default=None,
+                        help="pretrained VAE model for latent diffusion")

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
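The new --scheduler flag selects between DDPM and DDIM noise schedules; either way, training noises clean samples with scheduler.add_noise. A minimal sketch of that forward-diffusion step, assuming a recent diffusers release (which no longer needs the tensor_format="pt" argument this commit passes):

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# Stand-ins for a batch of four single-channel spectrograms.
clean_images = torch.randn(4, 1, 64, 64)
noise = torch.randn_like(clean_images)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (4,)).long()

# Forward diffusion: mix signal and noise according to each sample's timestep.
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
print(noisy_images.shape)  # torch.Size([4, 1, 64, 64])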
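With --vae set, the training step encodes batches into VAE latents before noising and decodes predictions back to pixel space. A sketch of that round trip, assuming a recent diffusers release (attribute-style .sample access rather than the ["sample"] indexing the diff uses); the model id is a placeholder:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("path/to/pretrained-vae")  # placeholder id

# Pretrained VAEs are usually 3-channel, which is why the transforms above
# convert grayscale spectrograms to RGB when a VAE is in use.
images = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    latents = vae.encode(images).latent_dist.sample()  # pixels -> latents
    decoded = vae.decode(latents).sample               # latents -> pixels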
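The accumulate/sync_gradients pattern the diff adopts only synchronizes gradients, and should only clip them and count a step, once every gradient_accumulation_steps batches. A self-contained sketch with a toy model and random data, assuming a recent accelerate release:

import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

global_step = 0
for step in range(100):
    batch = torch.randn(8, 16, device=accelerator.device)
    with accelerator.accumulate(model):
        loss = F.mse_loss(model(batch), batch)
        accelerator.backward(loss)
        if accelerator.sync_gradients:  # True once per effective optimizer step
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1  # count effective steps, not micro-batches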
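At checkpoint time the script bundles the trained components into a DDPMPipeline (or an LDMPipeline when a VAE is in play). A minimal sketch of that assembly with freshly initialized stand-in components; the output directory name is hypothetical:

from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

unet = UNet2DModel(sample_size=64, in_channels=1, out_channels=1)
scheduler = DDPMScheduler(num_train_timesteps=1000)

# Bundle model + scheduler so they can be reloaded with .from_pretrained().
pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)
pipeline.save_pretrained("audio-diffusion-output")  # hypothetical directory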
scripts/train_vae.py
CHANGED
@@ -6,6 +6,7 @@
 # grayscale
 # add vae to train_uncond (no_grad)
 # update README
+# merge in changes to train_unconditional

 import os
 import argparse
|