mattricesound committed on
Commit
106ab10
1 Parent(s): c04778c

Skip chunk if too quiet

Browse files
Files changed (2) hide show
  1. remfx/models.py +3 -0
  2. remfx/utils.py +3 -0
remfx/models.py CHANGED
@@ -84,6 +84,9 @@ class RemFX(pl.LightningModule):
84
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
85
 
86
  loss, output = self.model((x, y))
 
 
 
87
  self.log(f"{mode}_loss", loss)
88
  # Metric logging
89
  with torch.no_grad():
 
84
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
85
 
86
  loss, output = self.model((x, y))
87
+ # Crop target to match output
88
+ if output.shape[-1] < y.shape[-1]:
89
+ y = causal_crop(y, output.shape[-1])
90
  self.log(f"{mode}_loss", loss)
91
  # Metric logging
92
  with torch.no_grad():
remfx/utils.py CHANGED
@@ -158,6 +158,9 @@ def select_random_chunk(
158
  max_len = audio.shape[-1] - new_chunk_size
159
  random_start = torch.randint(0, max_len, (1,)).item()
160
  chunk = audio[:, random_start : random_start + new_chunk_size]
 
 
 
161
  resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
162
  return resampled_chunk
163
 
 
158
  max_len = audio.shape[-1] - new_chunk_size
159
  random_start = torch.randint(0, max_len, (1,)).item()
160
  chunk = audio[:, random_start : random_start + new_chunk_size]
161
+ # Skip if energy too low
162
+ if torch.mean(torch.abs(chunk)) < 1e-6:
163
+ return None
164
  resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
165
  return resampled_chunk
166