mattricesound committed on
Commit
ff526b3
·
1 Parent(s): 1530829

Add metrics

Browse files
Files changed (1) hide show
  1. remfx/models.py +33 -21
remfx/models.py CHANGED
@@ -4,7 +4,9 @@ import pytorch_lightning as pl
4
  from einops import rearrange
5
  import wandb
6
  from audio_diffusion_pytorch import DiffusionModel
7
- import auraloss
 
 
8
 
9
  from umx.openunmix.model import OpenUnmix, Separator
10
 
@@ -28,6 +30,13 @@ class RemFXModel(pl.LightningModule):
28
  self.lr_weight_decay = lr_weight_decay
29
  self.sample_rate = sample_rate
30
  self.model = network
 
 
 
 
 
 
 
31
 
32
  @property
33
  def device(self):
@@ -49,10 +58,23 @@ class RemFXModel(pl.LightningModule):
49
 
50
  def validation_step(self, batch, batch_idx):
51
  loss = self.common_step(batch, batch_idx, mode="valid")
 
52
 
53
  def common_step(self, batch, batch_idx, mode: str = "train"):
54
- loss = self.model(batch)
55
  self.log(f"{mode}_loss", loss)
 
 
 
 
 
 
 
 
 
 
 
 
56
  return loss
57
 
58
  def on_validation_epoch_start(self):
@@ -62,24 +84,13 @@ class RemFXModel(pl.LightningModule):
62
  if self.log_next:
63
  x, target, label = batch
64
  y = self.model.sample(x)
 
 
 
65
  log_wandb_audio_batch(
66
  logger=self.logger,
67
- id="sample",
68
- samples=x.cpu(),
69
- sampling_rate=self.sample_rate,
70
- caption=f"Epoch {self.current_epoch}",
71
- )
72
- log_wandb_audio_batch(
73
- logger=self.logger,
74
- id="prediction",
75
- samples=y.cpu(),
76
- sampling_rate=self.sample_rate,
77
- caption=f"Epoch {self.current_epoch}",
78
- )
79
- log_wandb_audio_batch(
80
- logger=self.logger,
81
- id="target",
82
- samples=target.cpu(),
83
  sampling_rate=self.sample_rate,
84
  caption=f"Epoch {self.current_epoch}",
85
  )
@@ -116,7 +127,7 @@ class OpenUnmixModel(torch.nn.Module):
116
  n_fft=self.n_fft,
117
  n_hop=self.hop_length,
118
  )
119
- self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
120
  n_bins=self.num_bins, sample_rate=self.sample_rate
121
  )
122
 
@@ -127,7 +138,7 @@ class OpenUnmixModel(torch.nn.Module):
127
  sep_out = self.separator(x).squeeze(1)
128
  loss = self.loss_fn(sep_out, target)
129
 
130
- return loss
131
 
132
  def sample(self, x: Tensor) -> Tensor:
133
  return self.separator(x).squeeze(1)
@@ -140,7 +151,8 @@ class DiffusionGenerationModel(nn.Module):
140
 
141
  def forward(self, batch):
142
  x, target, label = batch
143
- return self.model(x)
 
144
 
145
  def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
146
  noise = torch.randn(x.shape).to(x)
 
4
  from einops import rearrange
5
  import wandb
6
  from audio_diffusion_pytorch import DiffusionModel
7
+ from auraloss.time import SISDRLoss
8
+ from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
9
+ from torch.nn import L1Loss
10
 
11
  from umx.openunmix.model import OpenUnmix, Separator
12
 
 
30
  self.lr_weight_decay = lr_weight_decay
31
  self.sample_rate = sample_rate
32
  self.model = network
33
+ self.metrics = torch.nn.ModuleDict(
34
+ {
35
+ "SISDR": SISDRLoss(),
36
+ "STFT": STFTLoss(),
37
+ "L1": L1Loss(),
38
+ }
39
+ )
40
 
41
  @property
42
  def device(self):
 
58
 
59
  def validation_step(self, batch, batch_idx):
60
  loss = self.common_step(batch, batch_idx, mode="valid")
61
+ return loss
62
 
63
  def common_step(self, batch, batch_idx, mode: str = "train"):
64
+ loss, output = self.model(batch)
65
  self.log(f"{mode}_loss", loss)
66
+ x, y, label = batch
67
+ # Metric logging
68
+ for metric in self.metrics:
69
+ self.log(
70
+ f"{mode}_{metric}",
71
+ self.metrics[metric](output, y),
72
+ on_step=False,
73
+ on_epoch=True,
74
+ logger=True,
75
+ prog_bar=True,
76
+ )
77
+
78
  return loss
79
 
80
  def on_validation_epoch_start(self):
 
84
  if self.log_next:
85
  x, target, label = batch
86
  y = self.model.sample(x)
87
+
88
+ # Concat samples together for easier viewing in dashboard
89
+ concat_samples = torch.cat([x, y, target], dim=-1)
90
  log_wandb_audio_batch(
91
  logger=self.logger,
92
+ id="prediction_sample_target",
93
+ samples=concat_samples.cpu(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  sampling_rate=self.sample_rate,
95
  caption=f"Epoch {self.current_epoch}",
96
  )
 
127
  n_fft=self.n_fft,
128
  n_hop=self.hop_length,
129
  )
130
+ self.loss_fn = MultiResolutionSTFTLoss(
131
  n_bins=self.num_bins, sample_rate=self.sample_rate
132
  )
133
 
 
138
  sep_out = self.separator(x).squeeze(1)
139
  loss = self.loss_fn(sep_out, target)
140
 
141
+ return loss, sep_out
142
 
143
  def sample(self, x: Tensor) -> Tensor:
144
  return self.separator(x).squeeze(1)
 
151
 
152
  def forward(self, batch):
153
  x, target, label = batch
154
+ sampled_out = self.model.sample(x)
155
+ return self.model(x), sampled_out
156
 
157
  def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
158
  noise = torch.randn(x.shape).to(x)