teticio committed
Commit
d76bdef
1 Parent(s): 3e8b723
Files changed (1)
  1. train_vae.py +30 -122
train_vae.py CHANGED
@@ -4,7 +4,6 @@
 
 # TODO
 # grayscale
-# log audio
 # convert to huggingface / train huggingface
 
 import os
@@ -57,134 +56,46 @@ class AudioDiffusionDataModule(pl.LightningDataModule):
                           num_workers=self.num_workers)
 
 
-# from https://github.com/CompVis/stable-diffusion/blob/main/main.py
 class ImageLogger(Callback):
 
-    def __init__(self,
-                 batch_frequency,
-                 max_images,
-                 clamp=True,
-                 increase_log_steps=True,
-                 rescale=True,
-                 disabled=False,
-                 log_on_batch_idx=False,
-                 log_first_step=False,
-                 log_images_kwargs=None,
-                 resolution=256,
-                 hop_length=512):
+    def __init__(self, every=1000, resolution=256, hop_length=512):
         super().__init__()
         self.mel = Mel(x_res=resolution,
                        y_res=resolution,
                        hop_length=hop_length)
-        self.rescale = rescale
-        self.batch_freq = batch_frequency
-        self.max_images = max_images
-        self.logger_log_images = {
-            pl.loggers.TensorBoardLogger: self._testtube,
-        }
-        self.log_steps = [
-            2**n for n in range(int(np.log2(self.batch_freq)) + 1)
-        ]
-        if not increase_log_steps:
-            self.log_steps = [self.batch_freq]
-        self.clamp = clamp
-        self.disabled = disabled
-        self.log_on_batch_idx = log_on_batch_idx
-        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
-        self.log_first_step = log_first_step
-
-    #@rank_zero_only
-    def _testtube(self, pl_module, images, batch_idx, split):
+        self.every = every
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
+                           batch_idx):
+        if (batch_idx + 1) % self.every != 0:
+            return
+
+        pl_module.eval()
+        with torch.no_grad():
+            images = pl_module.log_images(batch, split='train')
+        pl_module.train()
+
         for k in images:
-            images_ = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
-            grid = torchvision.utils.make_grid(images_)
+            images[k] = images[k].detach().cpu()
+            images[k] = torch.clamp(images[k], -1., 1.)
+            images[k] = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = torchvision.utils.make_grid(images[k])
 
-            tag = f"{split}/{k}"
+            tag = f"train/{k}"
             pl_module.logger.experiment.add_image(
                 tag, grid, global_step=pl_module.global_step)
 
-            for _, image in enumerate(images_):
-                image = (images_.numpy() *
+            images[k] = (images[k].numpy() *
                          255).round().astype("uint8").transpose(0, 2, 3, 1)
+            for _, image in enumerate(images[k]):
                 audio = self.mel.image_to_audio(
-                    Image.fromarray(image[0], mode='RGB').convert('L'))
+                    Image.fromarray(image, mode='RGB').convert('L'))
                 pl_module.logger.experiment.add_audio(
                     tag + f"/{_}",
                     normalize(audio),
                     global_step=pl_module.global_step,
                     sample_rate=self.mel.get_sample_rate())
 
-    #@rank_zero_only
-    def log_local(self, save_dir, split, images, global_step, current_epoch,
-                  batch_idx):
-        root = os.path.join(save_dir, "images", split)
-        for k in images:
-            grid = torchvision.utils.make_grid(images[k], nrow=4)
-            if self.rescale:
-                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
-            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
-            grid = grid.numpy()
-            grid = (grid * 255).astype(np.uint8)
-            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
-                k, global_step, current_epoch, batch_idx)
-            path = os.path.join(root, filename)
-            os.makedirs(os.path.split(path)[0], exist_ok=True)
-            Image.fromarray(grid).save(path)
-
-    def log_img(self, pl_module, batch, batch_idx, split="train"):
-        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
-        if (self.check_frequency(check_idx)
-                and  # batch_idx % self.batch_freq == 0
-                hasattr(pl_module, "log_images") and
-                callable(pl_module.log_images) and self.max_images > 0):
-            logger = type(pl_module.logger)
-
-            is_train = pl_module.training
-            if is_train:
-                pl_module.eval()
-
-            with torch.no_grad():
-                images = pl_module.log_images(batch,
-                                              split=split,
-                                              **self.log_images_kwargs)
-
-            for k in images:
-                N = min(images[k].shape[0], self.max_images)
-                images[k] = images[k][:N]
-                if isinstance(images[k], torch.Tensor):
-                    images[k] = images[k].detach().cpu()
-                if self.clamp:
-                    images[k] = torch.clamp(images[k], -1., 1.)
-
-            #self.log_local(pl_module.logger.save_dir, split, images,
-            #               pl_module.global_step, pl_module.current_epoch,
-            #               batch_idx)
-
-            logger_log_images = self.logger_log_images.get(
-                logger, lambda *args, **kwargs: None)
-            logger_log_images(pl_module, images, pl_module.global_step, split)
-
-            if is_train:
-                pl_module.train()
-
-    def check_frequency(self, check_idx):
-        if ((check_idx % self.batch_freq) == 0 or
-            (check_idx in self.log_steps)) and (check_idx > 0
-                                                or self.log_first_step):
-            try:
-                self.log_steps.pop(0)
-            except IndexError as e:
-                #print(e)
-                pass
-            return True
-        return False
-
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
-                           batch_idx):
-        if not self.disabled and (pl_module.global_step > 0
-                                  or self.log_first_step):
-            self.log_img(pl_module, batch, batch_idx, split="train")
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train VAE using ldm.")
@@ -195,18 +106,15 @@ if __name__ == "__main__":
     lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_opt = argparse.Namespace(**trainer_config)
-    trainer = Trainer.from_argparse_args(
-        trainer_opt,
-        callbacks=[
-            ImageLogger(batch_frequency=1000,
-                        max_images=8,
-                        increase_log_steps=False,
-                        log_on_batch_idx=True),
-            ModelCheckpoint(dirpath='checkpoints',
-                            filename='{epoch:06}',
-                            verbose=True,
-                            save_last=True)
-        ])
+    trainer = Trainer.from_argparse_args(trainer_opt,
+                                         callbacks=[
+                                             ImageLogger(),
+                                             ModelCheckpoint(
+                                                 dirpath='checkpoints',
+                                                 filename='{epoch:06}',
+                                                 verbose=True,
+                                                 save_last=True)
+                                         ])
    model = instantiate_from_config(config.model)
    model.learning_rate = config.model.base_learning_rate
    data = AudioDiffusionDataModule('teticio/audio-diffusion-256',
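
The net effect of the change is a much smaller ImageLogger: every `every` batches it asks the model for reconstructions via log_images, writes them to TensorBoard as an image grid, and then inverts each spectrogram image back to audio and logs that as well (which is why the "# log audio" TODO is dropped). The following is a minimal, self-contained sketch of that round trip, not the repo's code: the audiodiffusion.mel import path, the image_to_audio()/get_sample_rate() methods, and the normalize() helper are assumptions based only on how they are used in the diff.

# Sketch of the logging round trip implemented by the new ImageLogger.
# Assumptions (not verified against the repo): Mel lives at audiodiffusion.mel,
# exposes image_to_audio() and get_sample_rate(), and normalize() is simple
# peak normalization as used in train_vae.py.
import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from audiodiffusion.mel import Mel  # assumed import path


def normalize(audio: np.ndarray) -> np.ndarray:
    # assumed equivalent of the script's helper: scale peaks to [-1, 1]
    return audio / np.abs(audio).max()


def log_reconstructions(writer: SummaryWriter, images: torch.Tensor,
                        mel: Mel, tag: str, step: int) -> None:
    # images: (N, C, H, W) in [-1, 1], e.g. one entry of pl_module.log_images()
    images = torch.clamp(images.detach().cpu(), -1., 1.)
    images = (images + 1.0) / 2.0  # -1,1 -> 0,1
    writer.add_image(tag,
                     torchvision.utils.make_grid(images),
                     global_step=step)

    # uint8 HWC for PIL, then invert each mel spectrogram back to a waveform
    arrays = (images.numpy() * 255).round().astype("uint8").transpose(
        0, 2, 3, 1)
    for i, array in enumerate(arrays):
        audio = mel.image_to_audio(
            Image.fromarray(array, mode='RGB').convert('L'))
        writer.add_audio(f"{tag}/{i}",
                         normalize(audio),
                         global_step=step,
                         sample_rate=mel.get_sample_rate())

With Lightning's default TensorBoardLogger the logged grids and audio clips should then be browsable with "tensorboard --logdir lightning_logs" (assuming no custom logger or save_dir is configured).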