teticio committed
Commit 13aa297
1 Parent(s): 1ef9d1c

take channels into account

config/ldm_autoencoder_kl.yaml CHANGED
@@ -19,7 +19,7 @@ model:
       in_channels: 3
       out_ch: 3
       ch: 128
-      ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
+      ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
       num_res_blocks: 2
       attn_resolutions: [ ]
       dropout: 0.0
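
The config's own comment gives the rule: num_down = len(ch_mult) - 1, so trimming ch_mult from [1,2,4,4] to [1,2,4] removes one downsampling stage and doubles the side length of the latent. A minimal sketch of the arithmetic, assuming square spectrograms and the 256 default that --resolution gets in the training script below:

# Sketch: latent resolution implied by ch_mult, using the config comment's
# rule num_down = len(ch_mult) - 1 (each downsampling stage halves the side).
def latent_resolution(input_res, ch_mult):
    num_down = len(ch_mult) - 1
    return input_res // 2 ** num_down

print(latent_resolution(256, [1, 2, 4, 4]))  # 32 with the old config
print(latent_resolution(256, [1, 2, 4]))     # 64 with the new config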
scripts/train_vae.py CHANGED
@@ -3,6 +3,7 @@
 
 # TODO
 # grayscale
+# docstrings
 
 import os
 import argparse
@@ -27,8 +28,9 @@ from audiodiffusion.utils import convert_ldm_to_hf_vae
 
 class AudioDiffusion(Dataset):
 
-    def __init__(self, model_id):
+    def __init__(self, model_id, channels=3):
         super().__init__()
+        self.channels = channels
         if os.path.exists(model_id):
             self.hf_dataset = load_from_disk(model_id)['train']
         else:
@@ -38,7 +40,9 @@ class AudioDiffusion(Dataset):
         return len(self.hf_dataset)
 
     def __getitem__(self, idx):
-        image = self.hf_dataset[idx]['image'].convert('RGB')
+        image = self.hf_dataset[idx]['image']
+        if self.channels == 3:
+            image = image.convert('RGB')
         image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
             (image.height, image.width, 3))
         image = ((image / 255) * 2 - 1)
@@ -47,10 +51,10 @@ class AudioDiffusion(Dataset):
 
 class AudioDiffusionDataModule(pl.LightningDataModule):
 
-    def __init__(self, model_id, batch_size):
+    def __init__(self, model_id, batch_size, channels):
         super().__init__()
         self.batch_size = batch_size
-        self.dataset = AudioDiffusion(model_id)
+        self.dataset = AudioDiffusion(model_id=model_id, channels=channels)
         self.num_workers = 1
 
     def train_dataloader(self):
@@ -61,12 +65,13 @@ class AudioDiffusionDataModule(pl.LightningDataModule):
 
 class ImageLogger(Callback):
 
-    def __init__(self, every=1000, resolution=256, hop_length=512):
+    def __init__(self, every=1000, channels=3, resolution=256, hop_length=512):
         super().__init__()
         self.mel = Mel(x_res=resolution,
                        y_res=resolution,
                        hop_length=hop_length)
         self.every = every
+        self.channels = channels
 
     @rank_zero_only
     def log_images_and_audios(self, pl_module, batch):
@@ -89,7 +94,8 @@ class ImageLogger(Callback):
                          255).round().astype("uint8").transpose(0, 2, 3, 1)
             for _, image in enumerate(images[k]):
                 audio = self.mel.image_to_audio(
-                    Image.fromarray(image, mode='RGB').convert('L'))
+                    Image.fromarray(image, mode='RGB').convert('L') if self.
+                    channels == 3 else Image.fromarray(image[0]))
                 pl_module.logger.experiment.add_audio(
                     tag + f"/{_}",
                     normalize(audio),
@@ -140,9 +146,17 @@ if __name__ == "__main__":
         "--gradient_accumulation_steps",
         type=int,
         default=1)
+    parser.add_argument("--resolution", type=int, default=256)
+    parser.add_argument("--hop_length", type=int, default=512)
     args = parser.parse_args()
 
     config = OmegaConf.load(args.ldm_config_file)
+    model = instantiate_from_config(config.model)
+    model.learning_rate = config.model.base_learning_rate
+    data = AudioDiffusionDataModule(
+        model_id=args.dataset_name,
+        batch_size=args.batch_size,
+        channels=config.model.params.ddconfig.in_channels)
     lightning_config = config.pop("lightning", OmegaConf.create())
     trainer_config = lightning_config.get("trainer", OmegaConf.create())
     trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
@@ -151,7 +165,9 @@ if __name__ == "__main__":
         trainer_opt,
         resume_from_checkpoint=args.resume_from_checkpoint,
         callbacks=[
-            ImageLogger(),
+            ImageLogger(channels=config.model.params.ddconfig.out_ch,
+                        resolution=args.resolution,
+                        hop_length=args.hop_length),
             HFModelCheckpoint(ldm_config=config,
                               hf_checkpoint=args.hf_checkpoint_dir,
                               dirpath=args.ldm_checkpoint_dir,
@@ -159,8 +175,4 @@ if __name__ == "__main__":
                               verbose=True,
                               save_last=True)
         ])
-    model = instantiate_from_config(config.model)
-    model.learning_rate = config.model.base_learning_rate
-    data = AudioDiffusionDataModule(args.dataset_name,
-                                    batch_size=args.batch_size)
     trainer.fit(model, data)
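
One caveat the grayscale TODO at the top already hints at: the new channels argument only gates the RGB conversion, while the reshape in __getitem__ still hard-codes 3, so a true single-channel dataset would still fail on that line. A hedged sketch of a channel-aware version, pulled out as a standalone helper for illustration (the helper name and the grayscale handling are assumptions, not part of this commit):

import numpy as np
from PIL import Image

# Hypothetical channel-aware variant of the __getitem__ body: reshape with
# the configured channel count instead of the hard-coded 3.
def image_to_array(image: Image.Image, channels: int = 3) -> np.ndarray:
    if channels == 3:
        image = image.convert('RGB')  # as in the commit
    array = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
        (image.height, image.width, channels))
    return (array / 255) * 2 - 1  # map uint8 [0, 255] to [-1, 1]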
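
A wiring detail worth noting: the data module takes its channel count from the autoencoder's input side (ddconfig.in_channels), while the ImageLogger takes the output side (ddconfig.out_ch); both are 3 in the YAML above, but they are read independently. A minimal sketch of pulling both values out of the config, with the path assumed from this repo's layout:

from omegaconf import OmegaConf

# Path assumed from this commit; both values are 3 in the YAML above.
config = OmegaConf.load("config/ldm_autoencoder_kl.yaml")
print(config.model.params.ddconfig.in_channels)  # feeds AudioDiffusionDataModule
print(config.model.params.ddconfig.out_ch)       # feeds ImageLogger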