mattricesound committed
Commit 0e3a05d
1 Parent(s): 0fbacb2

Add target cropping if outputs are different length

Files changed (2)
  1. remfx/models.py +7 -20
  2. remfx/tcn.py +0 -2
remfx/models.py CHANGED
@@ -13,6 +13,7 @@ from remfx.utils import FADLoss, spectrogram
 from remfx.dptnet import DPTNet_base
 from remfx.dcunet import RefineSpectrogramUnet
 from remfx.tcn import TCN
+from remfx.utils import causal_crop
 
 
 class RemFX(pl.LightningModule):
@@ -223,21 +224,14 @@ class DCUNetModel(nn.Module):
     def forward(self, batch):
         x, target = batch
         output = self.model(x.squeeze(1)).unsqueeze(1)  # B x 1 x T
-        # Pad or crop to match target
-        if output.shape[-1] > target.shape[-1]:
-            output = output[:, : target.shape[-1]]
-        elif output.shape[-1] < target.shape[-1]:
-            output = F.pad(output, (0, target.shape[-1] - output.shape[-1]))
+        # Crop target to match output
+        if output.shape[-1] < target.shape[-1]:
+            target = causal_crop(target, output.shape[-1])
         loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
 
     def sample(self, x: Tensor) -> Tensor:
         output = self.model(x.squeeze(1)).unsqueeze(1)  # B x 1 x T
-        # Pad or crop to match target
-        if output.shape[-1] > x.shape[-1]:
-            output = output[:, : x.shape[-1]]
-        elif output.shape[-1] < x.shape[-1]:
-            output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
         return output
 
 
@@ -253,21 +247,14 @@ class TCNModel(nn.Module):
     def forward(self, batch):
         x, target = batch
         output = self.model(x)  # B x 1 x T
-        # Pad or crop to match target
-        if output.shape[-1] > x.shape[-1]:
-            output = output[:, : x.shape[-1]]
-        elif output.shape[-1] < x.shape[-1]:
-            output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
+        # Crop target to match output
+        if output.shape[-1] < target.shape[-1]:
+            target = causal_crop(target, output.shape[-1])
         loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
 
     def sample(self, x: Tensor) -> Tensor:
         output = self.model(x)  # B x 1 x T
-        # Pad or crop to match target
-        if output.shape[-1] > x.shape[-1]:
-            output = output[:, : x.shape[-1]]
-        elif output.shape[-1] < x.shape[-1]:
-            output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
         return output
 
 
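The change above swaps the old pad-or-crop of the model output for a crop of the target down to the output length via causal_crop from remfx.utils. The helper itself is not shown in this commit; the snippet below is a minimal sketch that assumes causal_crop keeps the trailing `length` samples along the time axis (the usual convention for causal models, whose first samples fall outside the receptive field); the real remfx.utils.causal_crop may be implemented differently.

# Hypothetical stand-in for remfx.utils.causal_crop; the actual helper may differ.
import torch
from torch import Tensor

def causal_crop(x: Tensor, length: int) -> Tensor:
    """Keep only the last `length` time steps of x (B x C x T)."""
    if x.shape[-1] > length:
        x = x[..., -length:]
    return x

# Mirrors the new forward() logic: crop the target, not the output.
target = torch.randn(2, 1, 48000)  # full-length target
output = torch.randn(2, 1, 47600)  # model output, slightly shorter
if output.shape[-1] < target.shape[-1]:
    target = causal_crop(target, output.shape[-1])
assert target.shape == output.shape  # loss terms now see matching lengths

Cropping the target rather than padding the output means the MR-STFT and L1 losses are computed only over samples the model actually produced, instead of over zero-padded tails.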
remfx/tcn.py CHANGED
@@ -25,8 +25,6 @@ class TCNBlock(nn.Module):
         self.stride = stride
 
         self.crop_fn = crop_fn
-        # Assumes stride of 1
-        padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
         self.conv1 = nn.Conv1d(
             in_ch,
             out_ch,
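The two lines removed from tcn.py computed "same" padding for a dilated convolution under a stride-of-1 assumption; the value appears unused (the commit deletes it without changing the nn.Conv1d call), so the block shrinks its output and relies on self.crop_fn instead. A quick illustrative check of the removed formula, with made-up kernel size and dilation:

import torch
import torch.nn as nn

k, d, T = 13, 4, 1024  # illustrative kernel size, dilation, input length
padding = (k + (k - 1) * (d - 1) - 1) // 2  # the removed "same"-padding formula

same = nn.Conv1d(1, 1, kernel_size=k, dilation=d, padding=padding)
print(same(torch.randn(1, 1, T)).shape)      # torch.Size([1, 1, 1024]): length preserved

unpadded = nn.Conv1d(1, 1, kernel_size=k, dilation=d, padding=0)
print(unpadded(torch.randn(1, 1, T)).shape)  # torch.Size([1, 1, 976]): shrinks by the
                                             # effective kernel size minus one (48 here)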