mattricesound commited on
Commit
f8fea2a
1 Parent(s): 1b72821

Fix DCUNet

Browse files
Files changed (1) hide show
  1. remfx/models.py +23 -0
remfx/models.py CHANGED
@@ -326,6 +326,29 @@ class DPTNetModel(nn.Module):
326
  return self.model(x.squeeze(1))
327
 
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  class TCNModel(nn.Module):
330
  def __init__(self, sample_rate, num_bins, **kwargs):
331
  super().__init__()
 
326
  return self.model(x.squeeze(1))
327
 
328
 
329
+ class DCUNetModel(nn.Module):
330
+ def __init__(self, sample_rate, num_bins, **kwargs):
331
+ super().__init__()
332
+ self.model = asteroid.models.DCUNet(**kwargs)
333
+ self.mrstftloss = MultiResolutionSTFTLoss(
334
+ n_bins=num_bins, sample_rate=sample_rate
335
+ )
336
+ self.l1loss = nn.L1Loss()
337
+
338
+ def forward(self, batch):
339
+ x, target = batch
340
+ output = self.model(x.squeeze(1)) # B x T
341
+ # Crop target to match output
342
+ if output.shape[-1] < target.shape[-1]:
343
+ target = causal_crop(target, output.shape[-1])
344
+ loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
345
+ return loss, output
346
+
347
+ def sample(self, x: Tensor) -> Tensor:
348
+ output = self.model(x.squeeze(1)) # B x T
349
+ return output
350
+
351
+
352
  class TCNModel(nn.Module):
353
  def __init__(self, sample_rate, num_bins, **kwargs):
354
  super().__init__()