waidhoferj committed on
Commit
797a86a
•
1 Parent(s): 6ba247f

refactored loggers

Browse files
TODO.md CHANGED
@@ -12,6 +12,8 @@
12
  - ✅ Download songs from [Best Ballroom](https://www.youtube.com/channel/UC0bYSnzAFMwPiEjmVsrvmRg)
13
 
14
  - ✅ fix nan values
 
 
15
 
16
  ## Notes
17
 
 
12
  - ✅ Download songs from [Best Ballroom](https://www.youtube.com/channel/UC0bYSnzAFMwPiEjmVsrvmRg)
13
 
14
  - ✅ fix nan values
15
+ - Try higher mels (224) and more ffts (2048)
16
+ - Verify random sample of dataset outputs by hand.
17
 
18
  ## Notes
19
 
audio_utils.py DELETED
@@ -1,42 +0,0 @@
1
- import librosa
2
- from IPython.display import Audio, display
3
- import matplotlib.pyplot as plt
4
- import torch
5
- SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav"
6
-
7
- SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
8
-
9
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
    """Show a spectrogram on a decibel scale in a non-blocking matplotlib window.

    Args:
        spec: Tensor-like object supporting ``squeeze(0)`` and ``numpy()``;
            presumably a (1, freq, time) power spectrogram — TODO confirm.
        title: Plot title; a default title is used when falsy.
        ylabel: Label for the vertical axis.
        aspect: Aspect setting forwarded to ``imshow``.
        xmax: Optional upper x-limit (in frames); ignored when falsy.
    """
    power = spec.squeeze(0).numpy()

    figure, axis = plt.subplots(1, 1)
    axis.set_title(title if title else "Spectrogram (db)")
    axis.set_ylabel(ylabel)
    axis.set_xlabel("frame")

    # Convert power values to dB before drawing; origin="lower" puts bin 0 at the bottom.
    image = axis.imshow(librosa.power_to_db(power), origin="lower", aspect=aspect)
    if xmax:
        axis.set_xlim((0, xmax))

    figure.colorbar(image, ax=axis)
    plt.show(block=False)
21
-
22
def play_audio(waveform, sample_rate):
    """Display an inline audio player for a mono or stereo waveform.

    Args:
        waveform: Tensor-like object whose ``numpy()`` result is 2-D,
            unpacked as (channels, frames).
        sample_rate: Playback rate passed to the audio widget.

    Raises:
        ValueError: If the channel count is neither 1 nor 2.
    """
    samples = waveform.numpy()
    channel_count, _frame_count = samples.shape

    if channel_count not in (1, 2):
        raise ValueError("Waveform with more than 2 channels are not supported.")

    if channel_count == 1:
        payload = samples[0]
    else:
        # Stereo is handed over as a (left, right) pair.
        payload = (samples[0], samples[1])
    display(Audio(payload, rate=sample_rate))
32
-
33
def get_rir_sample(path, resample=None, processed=False):
    """Load a stored (tensor, sample_rate) pair and optionally post-process it.

    Args:
        path: Location (or file-like object) handed to ``torch.load``.
        resample: Accepted for call compatibility; not used by this function.
        processed: When true, return a trimmed, L2-normalised, time-reversed
            slice instead of the raw tensor.

    Returns:
        A ``(tensor, sample_rate)`` tuple.
    """
    # NOTE(review): torch.load unpickles data — only use with trusted files.
    raw_rir, rate = torch.load(path)

    if processed:
        # Keep the samples between 1.01 s and 1.3 s worth of frames.
        start, stop = int(rate * 1.01), int(rate * 1.3)
        window = raw_rir[:, start:stop]
        unit_norm = window / torch.norm(window, p=2)
        # Reverse along dimension 1.
        return torch.flip(unit_norm, [1]), rate

    return raw_rir, rate
41
-
42
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/config/train_local.yaml CHANGED
@@ -19,6 +19,7 @@ dance_ids: &dance_ids
19
  data_module:
20
  batch_size: 64
21
  num_workers: 10
 
22
  test_proportion: 0.2
23
 
24
  datasets:
@@ -55,4 +56,6 @@ trainer:
55
 
56
  training_environment:
57
  learning_rate: 0.00053
58
- log_spectrograms: False
 
 
 
19
  data_module:
20
  batch_size: 64
21
  num_workers: 10
22
+ data_subset: 0.1
23
  test_proportion: 0.2
24
 
25
  datasets:
 
56
 
57
  training_environment:
58
  learning_rate: 0.00053
59
+ loggers:
60
+ models.training_environment.SpectrogramLogger:
61
+ frequency: 100
models/training_environment.py CHANGED
@@ -1,6 +1,7 @@
 
1
  from models.utils import calculate_metrics
2
 
3
-
4
  import pytorch_lightning as pl
5
  import torch
6
  import torch.nn as nn
@@ -13,7 +14,6 @@ class TrainingEnvironment(pl.LightningModule):
13
  criterion: nn.Module,
14
  config: dict,
15
  learning_rate=1e-4,
16
- log_spectrograms=False,
17
  *args,
18
  **kwargs,
19
  ):
@@ -21,7 +21,9 @@ class TrainingEnvironment(pl.LightningModule):
21
  self.model = model
22
  self.criterion = criterion
23
  self.learning_rate = learning_rate
24
- self.log_spectrograms = log_spectrograms
 
 
25
  self.config = config
26
  self.has_multi_label_predictions = (
27
  not type(criterion).__name__ == "CrossEntropyLoss"
@@ -48,15 +50,9 @@ class TrainingEnvironment(pl.LightningModule):
48
  multi_label=self.has_multi_label_predictions,
49
  )
50
  self.log_dict(metrics, prog_bar=True)
51
- # Log spectrograms
52
- if self.log_spectrograms and batch_index % 100 == 0:
53
- tensorboard = self.logger.experiment
54
- img_index = torch.randint(0, len(features), (1,)).item()
55
- img = features[img_index][0]
56
- img = (img - img.min()) / (img.max() - img.min())
57
- tensorboard.add_image(
58
- f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
59
- )
60
  return loss
61
 
62
  def validation_step(
@@ -88,3 +84,36 @@ class TrainingEnvironment(pl.LightningModule):
88
  "lr_scheduler": scheduler,
89
  "monitor": "val/loss",
90
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
  from models.utils import calculate_metrics
3
 
4
+ from abc import ABC, abstractmethod
5
  import pytorch_lightning as pl
6
  import torch
7
  import torch.nn as nn
 
14
  criterion: nn.Module,
15
  config: dict,
16
  learning_rate=1e-4,
 
17
  *args,
18
  **kwargs,
19
  ):
 
21
  self.model = model
22
  self.criterion = criterion
23
  self.learning_rate = learning_rate
24
+ self.experiment_loggers = load_loggers(
25
+ config["training_environment"].get("loggers", {})
26
+ )
27
  self.config = config
28
  self.has_multi_label_predictions = (
29
  not type(criterion).__name__ == "CrossEntropyLoss"
 
50
  multi_label=self.has_multi_label_predictions,
51
  )
52
  self.log_dict(metrics, prog_bar=True)
53
+ experiment = self.logger.experiment
54
+ for logger in self.experiment_loggers:
55
+ logger.step(experiment, batch_index, features, labels)
 
 
 
 
 
 
56
  return loss
57
 
58
  def validation_step(
 
84
  "lr_scheduler": scheduler,
85
  "monitor": "val/loss",
86
  }
87
+
88
+
89
class ExperimentLogger(ABC):
    """Interface for auxiliary loggers that are driven once per training step."""

    @abstractmethod
    def step(self, experiment, batch_index, x, label):
        """Process one training batch.

        The signature matches how the training step invokes loggers:
        ``logger.step(experiment, batch_index, features, labels)``.

        Args:
            experiment: The ``experiment`` handle of the active Lightning logger.
            batch_index: Index of the current batch.
            x: Batch of input features.
            label: Batch of targets for ``x``.
        """


class SpectrogramLogger(ExperimentLogger):
    """Periodically writes one randomly chosen input of the batch as an image."""

    def __init__(self, frequency=100) -> None:
        """Create the logger.

        Args:
            frequency: Number of ``step`` calls between two logged images.
        """
        self.frequency = frequency
        self.counter = 0

    def step(self, experiment, batch_index, x, label):
        """Log a min-max normalised image every ``frequency`` calls.

        The counter starts at zero, so for ``frequency >= 1`` the first image
        is written on call ``frequency + 1`` and then every ``frequency`` calls.

        Args:
            experiment: Object exposing ``add_image(tag, img, step, dataformats=...)``;
                presumably a TensorBoard writer — TODO confirm.
            batch_index: Index of the current batch, used in the image tag.
            x: Batch of inputs; ``x[i][0]`` must be a 2-D tensor
                (assumes (batch, channel, H, W) — TODO confirm).
            label: Unused; present to satisfy the ``ExperimentLogger`` interface.
        """
        if self.counter == self.frequency:
            self.counter = 0
            # An empty batch has nothing to sample (torch.randint(0, 0) would raise).
            if len(x) > 0:
                img_index = torch.randint(0, len(x), (1,)).item()
                img = x[img_index][0]
                low = img.min()
                span = img.max() - low
                # A constant image has span 0; dividing by it would yield NaNs,
                # so fall back to a divisor of 1 (image becomes all zeros).
                divisor = torch.where(span == 0, torch.ones_like(span), span)
                img = (img - low) / divisor
                experiment.add_image(
                    f"batch: {batch_index}, element: {img_index}",
                    img,
                    0,
                    dataformats="HW",
                )
        self.counter += 1
110
+
111
+
112
def load_loggers(logger_config: dict) -> "list[ExperimentLogger]":
    """Instantiate loggers from a mapping of dotted class paths to kwargs.

    Args:
        logger_config: Mapping such as
            ``{"package.module.ClassName": {"frequency": 100}}``. A value of
            ``None`` (a YAML key with no arguments) means "no keyword
            arguments"; a ``None`` mapping means "no loggers".

    Returns:
        One instance per entry, in mapping order.

    Raises:
        ValueError: If a path contains no ``.`` separating module and class.
    """
    loggers = []
    # An empty `loggers:` key in YAML parses to None rather than {}.
    for logger_path, kwargs in (logger_config or {}).items():
        module_name, separator, class_name = logger_path.rpartition(".")
        if not separator:
            raise ValueError(
                f"Logger path {logger_path!r} must look like 'package.module.ClassName'"
            )
        module = importlib.import_module(module_name)
        logger_class = getattr(module, class_name)
        # A YAML entry without arguments yields None instead of a dict.
        loggers.append(logger_class(**(kwargs or {})))
    return loggers