mattricesound committed on
Commit a89496d • 1 Parent(s): abb9ffa

Refactor to use hydra

.gitignore CHANGED
@@ -6,4 +6,6 @@ data/
 .DS_Store
 __pycache__/
 lightning_logs/
-RemFX/
+RemFX/
+outputs/
+logs/
config.yaml ADDED
@@ -0,0 +1,50 @@
+defaults:
+  - _self_
+  - exp: null
+seed: 12345
+train: True
+length: 262144
+sample_rate: 22050
+logs_dir: "./logs"
+log_every_n_steps: 1000
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    monitor: "valid_loss" # name of the logged metric which determines when model is improving
+    save_top_k: 1 # save k best models (determined by above metric)
+    save_last: True # additionally always save model from last epoch
+    mode: "min" # can be "max" or "min"
+    verbose: False
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    filename: '{epoch:02d}-{valid_loss:.3f}'
+
+datamodule:
+  _target_: datasets.Datamodule
+  dataset:
+    _target_: datasets.GuitarFXDataset
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    length: ${length}
+  val_split: 0.2
+  batch_size: 16
+  num_workers: 8
+  pin_memory: True
+
+logger:
+  _target_: pytorch_lightning.loggers.WandbLogger
+  project: ${oc.env:WANDB_PROJECT}
+  entity: ${oc.env:WANDB_ENTITY}
+  # offline: False # set True to store all logs only locally
+  job_type: "train"
+  group: ""
+  save_dir: "."
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  precision: 32 # Precision used for tensors, default `32`
+  min_epochs: 0
+  max_epochs: -1
+  enable_model_summary: False
+  log_every_n_steps: 1 # Logs metrics every N batches
+  accumulate_grad_batches: 1
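
For reference, here is a minimal sketch (not part of the commit, assuming OmegaConf >= 2.1) of how the two interpolation styles above resolve: ${sample_rate} reads another key in the same config, while ${oc.env:DATASET_ROOT} reads an environment variable exported by shell_vars.sh below.

import os
from omegaconf import OmegaConf

os.environ["DATASET_ROOT"] = "/tmp/egfx"  # hypothetical path; normally exported by shell_vars.sh
cfg = OmegaConf.create(
    {
        "sample_rate": 22050,
        "datamodule": {
            "dataset": {
                "sample_rate": "${sample_rate}",
                "root": "${oc.env:DATASET_ROOT}",
            }
        },
    }
)
print(cfg.datamodule.dataset.sample_rate)  # -> 22050, resolved lazily on access
print(cfg.datamodule.dataset.root)         # -> /tmp/egfx
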
datasets.py CHANGED
@@ -1,10 +1,10 @@
-import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, random_split
 import torchaudio
 import torchaudio.transforms as T
 import torch.nn.functional as F
 from pathlib import Path
-from typing import List
+import pytorch_lightning as pl
+from typing import Any, List
 
 # https://zenodo.org/record/7044411/
 
@@ -18,18 +18,19 @@ class GuitarFXDataset(Dataset):
         root: str,
         sample_rate: int,
         length: int = LENGTH,
-        effect_type: List[str] = None,
+        effect_types: List[str] = None,
     ):
         self.length = length
         self.wet_files = []
         self.dry_files = []
         self.labels = []
         self.root = Path(root)
-        if effect_type is None:
-            effect_type = [
+
+        if effect_types is None:
+            effect_types = [
                 d.name for d in self.root.iterdir() if d.is_dir() and d != "Clean"
             ]
-        for i, effect in enumerate(effect_type):
+        for i, effect in enumerate(effect_types):
             for pickup in Path(self.root / effect).iterdir():
                 self.wet_files += sorted(list(pickup.glob("*.wav")))
                 self.dry_files += sorted(
@@ -61,3 +62,50 @@ class GuitarFXDataset(Dataset):
         elif resampled_y.shape[-1] > self.length:
             resampled_y = resampled_y[:, : self.length]
         return (resampled_x, resampled_y, effect_label)
+
+
+class Datamodule(pl.LightningDataModule):
+    def __init__(
+        self,
+        dataset,
+        *,
+        val_split: float,
+        batch_size: int,
+        num_workers: int,
+        pin_memory: bool = False,
+        **kwargs: int,
+    ) -> None:
+        super().__init__()
+        self.dataset = dataset
+        self.val_split = val_split
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.data_train: Any = None
+        self.data_val: Any = None
+
+    def setup(self, stage: Any = None) -> None:
+        split = [1.0 - self.val_split, self.val_split]
+        train_size = int(split[0] * len(self.dataset))
+        val_size = int(split[1] * len(self.dataset))
+        self.data_train, self.data_val = random_split(
+            self.dataset, [train_size, val_size]
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.data_train,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=True,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.data_val,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
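
Taken together, Datamodule replaces the manual random_split/DataLoader wiring that previously lived in train.py. A standalone usage sketch (not part of the commit; the dataset path is hypothetical) follows; note that torch's random_split requires the two integer sizes to sum to len(dataset), so setup() assumes a dataset length that splits cleanly.

from datasets import GuitarFXDataset, Datamodule

dataset = GuitarFXDataset(
    root="./data/egfx",      # hypothetical local copy of the EGFX dataset
    sample_rate=22050,
    length=262144,
    effect_types=["RAT"],
)
dm = Datamodule(dataset, val_split=0.2, batch_size=16, num_workers=8)
dm.setup()
wet, dry, label = next(iter(dm.train_dataloader()))  # wet/dry: (16, 1, 262144)
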
exp/audio_diffusion.yaml ADDED
@@ -0,0 +1,15 @@
+# @package _global_
+model:
+  _target_: models.RemFXModel
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  network:
+    _target_: models.DiffusionGenerationModel
+    n_channels: 1
+datamodule:
+  dataset:
+    effect_types: ["Clean"]
+  batch_size: 2
exp/demucs.yaml ADDED
@@ -0,0 +1 @@
+# @package _global_
exp/umx.yaml ADDED
@@ -0,0 +1,18 @@
+# @package _global_
+model:
+  _target_: models.RemFXModel
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  network:
+    _target_: models.OpenUnmixModel
+    n_fft: 2048
+    hop_length: 512
+    n_channels: 1
+    alpha: 0.3
+    sample_rate: ${sample_rate}
+datamodule:
+  dataset:
+    effect_types: ["RAT"]
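
The # @package _global_ directive makes each experiment file merge at the config root (rather than under an exp key), so its model and datamodule entries land exactly where train.py looks for them. A minimal sketch (not part of the commit, assuming Hydra >= 1.2 and the repo root as the working directory) of the composed result, equivalent to running `python train.py exp=umx`:

from hydra import compose, initialize

with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="config", overrides=["exp=umx"])

assert cfg.model._target_ == "models.RemFXModel"             # merged at the root
assert list(cfg.datamodule.dataset.effect_types) == ["RAT"]  # overrides config.yaml
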
main.py DELETED
@@ -1,19 +0,0 @@
-from audio_diffusion_pytorch import AudioDiffusionModel
-import torch
-from tqdm import tqdm
-import wandb
-
-model = AudioDiffusionModel(in_channels=1)
-wandb.init(project="RemFX", entity="mattricesound")
-
-x = torch.randn(2, 1, 2**18)
-for i in tqdm(range(100)):
-    loss = model(x)
-    loss.backward()
-    if i % 10 == 0:
-        print(loss)
-        wandb.log({"loss": loss})
-
-
-noise = torch.randn(2, 1, 2**18)
-sampled = model.sample(noise=noise, num_steps=5)
models.py CHANGED
@@ -1,9 +1,10 @@
 import torch
-from torch import Tensor
+from torch import Tensor, nn
 import pytorch_lightning as pl
 from einops import rearrange
 import wandb
 from audio_diffusion_pytorch import AudioDiffusionModel
+import auraloss
 
 import sys
 
@@ -14,50 +15,49 @@ from umx.openunmix.model import OpenUnmix, Separator
 SAMPLE_RATE = 22050  # From audio-diffusion-pytorch
 
 
-class OpenUnmixModel(pl.LightningModule):
+class RemFXModel(pl.LightningModule):
     def __init__(
         self,
-        n_fft: int = 2048,
-        hop_length: int = 512,
-        alpha: float = 0.3,
+        lr: float,
+        lr_beta1: float,
+        lr_beta2: float,
+        lr_eps: float,
+        lr_weight_decay: float,
+        network: nn.Module,
     ):
         super().__init__()
-        self.model = OpenUnmix(
-            nb_channels=1,
-            nb_bins=n_fft // 2 + 1,
-        )
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.alpha = alpha
-        window = torch.hann_window(n_fft)
-        self.register_buffer("window", window)
+        self.lr = lr
+        self.lr_beta1 = lr_beta1
+        self.lr_beta2 = lr_beta2
+        self.lr_eps = lr_eps
+        self.lr_weight_decay = lr_weight_decay
+        self.model = network
 
-    def forward(self, x: torch.Tensor):
-        return self.model(x)
+    @property
+    def device(self):
+        return next(self.model.parameters()).device
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            list(self.model.parameters()),
+            lr=self.lr,
+            betas=(self.lr_beta1, self.lr_beta2),
+            eps=self.lr_eps,
+            weight_decay=self.lr_weight_decay,
+        )
+        return optimizer
 
     def training_step(self, batch, batch_idx):
-        loss, _ = self.common_step(batch, batch_idx, mode="train")
+        loss = self.common_step(batch, batch_idx, mode="train")
         return loss
 
     def validation_step(self, batch, batch_idx):
-        loss, Y = self.common_step(batch, batch_idx, mode="val")
-        return loss, Y
+        loss = self.common_step(batch, batch_idx, mode="valid")
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
-        x, target, label = batch
-        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
-        Y = self(X)
-        Y_hat = spectrogram(
-            target, self.window, self.n_fft, self.hop_length, self.alpha
-        )
-        loss = torch.nn.functional.mse_loss(Y, Y_hat)
-        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
-        return loss, Y
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(
-            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
-        )
+        loss = self.model(batch)
+        self.log(f"{mode}_loss", loss)
+        return loss
 
     def on_validation_epoch_start(self):
         self.log_next = True
@@ -65,14 +65,7 @@ class OpenUnmixModel(pl.LightningModule):
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         if self.log_next:
             x, target, label = batch
-            s = Separator(
-                target_models={"other": self.model},
-                nb_channels=1,
-                sample_rate=SAMPLE_RATE,
-                n_fft=self.n_fft,
-                n_hop=self.hop_length,
-            ).to(self.device)
-            outputs = s(x).squeeze(1)
+            y = self.model.sample(x)
             log_wandb_audio_batch(
                 logger=self.logger,
                 id="sample",
@@ -83,12 +76,12 @@ class OpenUnmixModel(pl.LightningModule):
             log_wandb_audio_batch(
                 logger=self.logger,
                 id="prediction",
-                samples=outputs.cpu(),
+                samples=y.cpu(),
                 sampling_rate=SAMPLE_RATE,
                 caption=f"Epoch {self.current_epoch}",
            )
            log_wandb_audio_batch(
-                logger=self.loggger,
+                logger=self.logger,
                 id="target",
                 samples=target.cpu(),
                 sampling_rate=SAMPLE_RATE,
@@ -97,55 +90,65 @@ class OpenUnmixModel(pl.LightningModule):
             self.log_next = False
 
 
-class DiffusionGenerationModel(pl.LightningModule):
-    def __init__(self, model: torch.nn.Module):
+class OpenUnmixModel(torch.nn.Module):
+    def __init__(
+        self,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        n_channels: int = 1,
+        alpha: float = 0.3,
+        sample_rate: int = 22050,
+    ):
         super().__init__()
-        self.model = model
+        self.n_channels = n_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.alpha = alpha
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
 
-    def forward(self, x: torch.Tensor):
-        return self.model(x)
+        self.num_bins = self.n_fft // 2 + 1
+        self.sample_rate = sample_rate
+        self.model = OpenUnmix(
+            nb_channels=self.n_channels,
+            nb_bins=self.num_bins,
+        )
+        self.separator = Separator(
+            target_models={"other": self.model},
+            nb_channels=self.n_channels,
+            sample_rate=self.sample_rate,
+            n_fft=self.n_fft,
+            n_hop=self.hop_length,
+        )
+        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
+            n_bins=self.num_bins, sample_rate=self.sample_rate
+        )
 
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.model.sample(*args, **kwargs)
+    def forward(self, batch):
+        x, target, label = batch
+        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
+        Y = self.model(X)
+        sep_out = self.separator(x).squeeze(1)
+        loss = self.loss_fn(sep_out, target)
 
-    def training_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="train")
         return loss
 
-    def validation_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="val")
+    def sample(self, x: Tensor) -> Tensor:
+        return self.separator(x).squeeze(1)
 
-    def common_step(self, batch, batch_idx, mode: str = "train"):
-        x, target, label = batch
-        loss = self(x)
-        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
-        return loss
 
-    def configure_optimizers(self):
-        return torch.optim.Adam(
-            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
-        )
-
-    def on_validation_epoch_start(self):
-        self.log_next = True
+class DiffusionGenerationModel(nn.Module):
+    def __init__(self, n_channels: int = 1):
+        super().__init__()
+        self.model = AudioDiffusionModel(in_channels=n_channels)
 
-    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+    def forward(self, batch):
         x, target, label = batch
-        if self.log_next:
-            self.log_sample(x)
-            self.log_next = False
+        return self.model(x)
 
-    @torch.no_grad()
-    def log_sample(self, batch, num_steps=10):
-        # Get start diffusion noise
-        noise = torch.randn(batch.shape, device=self.device)
-        sampled = self.sample(noise=noise, num_steps=num_steps)  # Suggested range: 2-50
-        log_wandb_audio_batch(
-            id="sample",
-            samples=sampled,
-            sampling_rate=SAMPLE_RATE,
-            caption=f"Sampled in {num_steps} steps",
-        )
+    def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
+        noise = torch.randn(x.shape)
+        return self.model.sample(noise, num_steps=num_steps)
 
 
 def log_wandb_audio_batch(
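
To see how the refactored pieces fit together, here is a smoke-test sketch (not part of the commit) that wires RemFXModel to one of the plain nn.Module networks; shapes follow config.yaml (mono audio, length 262144 = 2**18). The network is called directly because common_step's self.log() expects an attached Trainer.

import torch
from models import RemFXModel, DiffusionGenerationModel

model = RemFXModel(
    lr=1e-4,
    lr_beta1=0.95,
    lr_beta2=0.999,
    lr_eps=1e-6,
    lr_weight_decay=1e-3,
    network=DiffusionGenerationModel(n_channels=1),
)
batch = (
    torch.randn(2, 1, 262144),  # wet (effected) audio
    torch.randn(2, 1, 262144),  # dry target
    torch.tensor([0, 0]),       # effect labels
)
loss = model.model(batch)  # the wrapped network computes the loss itself
audio = model.model.sample(torch.randn(2, 1, 262144), num_steps=10)
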
Experiments.ipynb → notebooks/Experiments.ipynb RENAMED
File without changes
diffusion_test.ipynb → notebooks/diffusion_test.ipynb RENAMED
File without changes
egfx.ipynb → notebooks/egfx.ipynb RENAMED
File without changes
guitar_generation_test.ipynb → notebooks/guitar_generation_test.ipynb RENAMED
File without changes
setup.py CHANGED
@@ -42,6 +42,8 @@ setup(
         "ema_pytorch",
         "einops",
         "librosa",
+        "hydra-core",
+        "auraloss",
     ],
     include_package_data=True,
     license="Apache License 2.0",
shell_vars.sh ADDED
@@ -0,0 +1,3 @@
+export DATASET_ROOT="/Users/matthewrice/Developer/remfx/data/egfx"
+export WANDB_PROJECT="RemFX"
+export WANDB_ENTITY="mattricesound"
train.py CHANGED
@@ -1,35 +1,50 @@
 from pytorch_lightning.loggers import WandbLogger
 import pytorch_lightning as pl
-import torch
 from torch.utils.data import DataLoader
 from datasets import GuitarFXDataset
 from models import DiffusionGenerationModel, OpenUnmixModel
+import hydra
+from omegaconf import DictConfig
+import utils
 
-
-SAMPLE_RATE = 22050
-TRAIN_SPLIT = 0.8
+log = utils.get_logger(__name__)
 
 
-def main():
-    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
-    trainer = pl.Trainer(logger=wandb_logger, max_epochs=100)
-    guitfx = GuitarFXDataset(
-        root="./data/egfx",
-        sample_rate=SAMPLE_RATE,
-        effect_type=["Phaser"],
-    )
-    train_size = int(TRAIN_SPLIT * len(guitfx))
-    val_size = len(guitfx) - train_size
-    train_dataset, val_dataset = torch.utils.data.random_split(
-        guitfx, [train_size, val_size]
-    )
-    train = DataLoader(train_dataset, batch_size=2)
-    val = DataLoader(val_dataset, batch_size=2)
+@hydra.main(version_base=None, config_path=".", config_name="config.yaml")
+def main(cfg: DictConfig):
+    # Apply seed for reproducibility
+    print(cfg)
+    pl.seed_everything(cfg.seed)
+
+    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
+    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
+
+    log.info(f"Instantiating model <{cfg.model._target_}>.")
+    model = hydra.utils.instantiate(cfg.model, _convert_="partial")
 
-    # model = DiffusionGenerationModel()
-    model = OpenUnmixModel()
+    # Init all callbacks
+    callbacks = []
+    if "callbacks" in cfg:
+        for _, cb_conf in cfg["callbacks"].items():
+            if "_target_" in cb_conf:
+                log.info(f"Instantiating callback <{cb_conf._target_}>.")
+                callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
 
-    trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
+    logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
+    trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
+    )
+    log.info("Logging hyperparameters!")
+    utils.log_hyperparameters(
+        config=cfg,
+        model=model,
+        datamodule=datamodule,
+        trainer=trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+    trainer.fit(model=model, datamodule=datamodule)
 
 
 if __name__ == "__main__":
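
hydra.utils.instantiate does the heavy lifting in main(): it imports each _target_ dotted path, calls it with the sibling keys as kwargs, and recurses into nested nodes such as datamodule.dataset. A minimal sketch of the datamodule case (not part of the commit; the root path is hypothetical):

import hydra
from omegaconf import OmegaConf

node = OmegaConf.create(
    {
        "_target_": "datasets.Datamodule",
        "dataset": {
            "_target_": "datasets.GuitarFXDataset",
            "root": "./data/egfx",  # hypothetical path
            "sample_rate": 22050,
            "length": 262144,
        },
        "val_split": 0.2,
        "batch_size": 16,
        "num_workers": 8,
        "pin_memory": True,
    }
)
datamodule = hydra.utils.instantiate(node, _convert_="partial")  # dataset is built first, then Datamodule
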
utils.py ADDED
@@ -0,0 +1,71 @@
+import logging
+from typing import List
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def get_logger(name=__name__) -> logging.Logger:
+    """Initializes multi-GPU-friendly python command line logger."""
+
+    logger = logging.getLogger(name)
+
+    # this ensures all logging levels get marked with the rank zero decorator
+    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+    for level in (
+        "debug",
+        "info",
+        "warning",
+        "error",
+        "exception",
+        "fatal",
+        "critical",
+    ):
+        setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+    return logger
+
+
+log = get_logger(__name__)
+
+
+@rank_zero_only
+def log_hyperparameters(
+    config: DictConfig,
+    model: pl.LightningModule,
+    datamodule: pl.LightningDataModule,
+    trainer: pl.Trainer,
+    callbacks: List[pl.Callback],
+    logger: pl.loggers.LightningLoggerBase,
+) -> None:
+    """Controls which config parts are saved by Lightning loggers.
+    Additionally saves:
+    - number of model parameters
+    """
+
+    if not trainer.logger:
+        return
+
+    hparams = {}
+
+    # choose which parts of hydra config will be saved to loggers
+    hparams["model"] = config["model"]
+
+    # save number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )
+
+    hparams["datamodule"] = config["datamodule"]
+    hparams["trainer"] = config["trainer"]
+
+    if "seed" in config:
+        hparams["seed"] = config["seed"]
+    if "callbacks" in config:
+        hparams["callbacks"] = config["callbacks"]
+
+    logger.experiment.config.update(hparams)
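
For reference, a short usage sketch: get_logger returns a standard logging.Logger whose level methods are wrapped with rank_zero_only, so under multi-GPU DDP each message is emitted once rather than once per process.

import utils

log = utils.get_logger(__name__)
log.info("printed once per run, even with multiple DDP processes")
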