mattricesound commited on
Commit
f07a3b6
1 Parent(s): 1efdf7d

Fix wandb on gpu

Browse files
Files changed (2) hide show
  1. models.py +13 -6
  2. train.py +1 -1
models.py CHANGED
@@ -71,23 +71,26 @@ class OpenUnmixModel(pl.LightningModule):
71
  sample_rate=SAMPLE_RATE,
72
  n_fft=self.n_fft,
73
  n_hop=self.hop_length,
74
- )
75
  outputs = s(x).squeeze(1)
76
  log_wandb_audio_batch(
 
77
  id="sample",
78
- samples=x,
79
  sampling_rate=SAMPLE_RATE,
80
  caption=f"Epoch {self.current_epoch}",
81
  )
82
  log_wandb_audio_batch(
 
83
  id="prediction",
84
- samples=outputs,
85
  sampling_rate=SAMPLE_RATE,
86
  caption=f"Epoch {self.current_epoch}",
87
  )
88
  log_wandb_audio_batch(
 
89
  id="target",
90
- samples=target,
91
  sampling_rate=SAMPLE_RATE,
92
  caption=f"Epoch {self.current_epoch}",
93
  )
@@ -146,12 +149,16 @@ class DiffusionGenerationModel(pl.LightningModule):
146
 
147
 
148
  def log_wandb_audio_batch(
149
- id: str, samples: Tensor, sampling_rate: int, caption: str = ""
 
 
 
 
150
  ):
151
  num_items = samples.shape[0]
152
  samples = rearrange(samples, "b c t -> b t c")
153
  for idx in range(num_items):
154
- wandb.log(
155
  {
156
  f"{id}_{idx}": wandb.Audio(
157
  samples[idx].cpu().numpy(),
 
71
  sample_rate=SAMPLE_RATE,
72
  n_fft=self.n_fft,
73
  n_hop=self.hop_length,
74
+ ).to(self.device)
75
  outputs = s(x).squeeze(1)
76
  log_wandb_audio_batch(
77
+ logger=self.logger,
78
  id="sample",
79
+ samples=x.cpu(),
80
  sampling_rate=SAMPLE_RATE,
81
  caption=f"Epoch {self.current_epoch}",
82
  )
83
  log_wandb_audio_batch(
84
+ logger=self.logger,
85
  id="prediction",
86
+ samples=outputs.cpu(),
87
  sampling_rate=SAMPLE_RATE,
88
  caption=f"Epoch {self.current_epoch}",
89
  )
90
  log_wandb_audio_batch(
91
+ logger=self.loggger,
92
  id="target",
93
+ samples=target.cpu(),
94
  sampling_rate=SAMPLE_RATE,
95
  caption=f"Epoch {self.current_epoch}",
96
  )
 
149
 
150
 
151
  def log_wandb_audio_batch(
152
+ logger: pl.loggers.WandbLogger,
153
+ id: str,
154
+ samples: Tensor,
155
+ sampling_rate: int,
156
+ caption: str = "",
157
  ):
158
  num_items = samples.shape[0]
159
  samples = rearrange(samples, "b c t -> b t c")
160
  for idx in range(num_items):
161
+ logger.experiment.log(
162
  {
163
  f"{id}_{idx}": wandb.Audio(
164
  samples[idx].cpu().numpy(),
train.py CHANGED
@@ -12,7 +12,7 @@ TRAIN_SPLIT = 0.8
12
 
13
  def main():
14
  wandb_logger = WandbLogger(project="RemFX", save_dir="./")
15
- trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
16
  guitfx = GuitarFXDataset(
17
  root="./data/egfx",
18
  sample_rate=SAMPLE_RATE,
 
12
 
13
  def main():
14
  wandb_logger = WandbLogger(project="RemFX", save_dir="./")
15
+ trainer = pl.Trainer(logger=wandb_logger, max_epochs=100)
16
  guitfx = GuitarFXDataset(
17
  root="./data/egfx",
18
  sample_rate=SAMPLE_RATE,