teticio commited on
Commit
7aaaf62
1 Parent(s): 8aa7c27

added vae notebook

Browse files
notebooks/test_vae.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
scripts/train_unconditional.py CHANGED
@@ -11,7 +11,7 @@ from accelerate import Accelerator
11
  from accelerate.logging import get_logger
12
  from datasets import load_from_disk, load_dataset
13
  from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
14
- DDIMScheduler, VQModel)
15
  from diffusers.hub_utils import init_git_repo, push_to_hub
16
  from diffusers.optimization import get_scheduler
17
  from diffusers.training_utils import EMAModel
@@ -46,11 +46,11 @@ def main(args):
46
  vqvae = pretrained.vqvae
47
  model = pretrained.unet
48
  else:
49
- vqvae = VQModel(sample_size=args.resolution,
50
- in_channels=1,
51
- out_channels=1,
52
- latent_channels=1,
53
- layers_per_block=2)
54
  model = UNet2DModel(
55
  sample_size=args.resolution,
56
  in_channels=1,
 
11
  from accelerate.logging import get_logger
12
  from datasets import load_from_disk, load_dataset
13
  from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
14
+ DDIMScheduler, AutoencoderKL)
15
  from diffusers.hub_utils import init_git_repo, push_to_hub
16
  from diffusers.optimization import get_scheduler
17
  from diffusers.training_utils import EMAModel
 
46
  vqvae = pretrained.vqvae
47
  model = pretrained.unet
48
  else:
49
+ vqvae = AutoencoderKL(sample_size=args.resolution,
50
+ in_channels=1,
51
+ out_channels=1,
52
+ latent_channels=1,
53
+ layers_per_block=2)
54
  model = UNet2DModel(
55
  sample_size=args.resolution,
56
  in_channels=1,
scripts/train_vae.py CHANGED
@@ -152,7 +152,7 @@ if __name__ == "__main__":
152
  trainer_opt,
153
  resume_from_checkpoint=args.resume_from_checkpoint,
154
  callbacks=[
155
- ImageLogger(),
156
  HFModelCheckpoint(ldm_config=config,
157
  hf_checkpoint=args.hf_checkpoint_dir,
158
  dirpath=args.ldm_checkpoint_dir,
 
152
  trainer_opt,
153
  resume_from_checkpoint=args.resume_from_checkpoint,
154
  callbacks=[
155
+ ImageLogger(every=10),
156
  HFModelCheckpoint(ldm_config=config,
157
  hf_checkpoint=args.hf_checkpoint_dir,
158
  dirpath=args.ldm_checkpoint_dir,